class GloVe:
    r"""The GloVe model for social event detection that uses GloVe word
    embeddings to detect events in social media data.

    Each document is represented by the mean of the GloVe vectors of its
    words; the document vectors are then clustered with KMeans and the
    cluster ids are used as predicted event labels.

    .. note::
        This detector uses word embeddings to identify events in social media
        data. The model requires a dataset object with a load_data() method.

    Parameters
    ----------
    dataset : object
        The dataset object containing social media data. Must provide
        load_data() method that returns the raw data.
    num_clusters : int, optional
        Number of clusters for KMeans clustering. Default: ``50``.
    random_state : int, optional
        Random seed for reproducibility. Default: ``1``.
    file_path : str, optional
        Path to save model files. Default: ``'../model/model_saved/GloVe/'``.
    model : str, optional
        Path to pre-trained GloVe word vectors file.
        Default: ``'../model/model_needed/glove.6B.100d.txt'``.
    """

    def __init__(self, dataset, num_clusters=50, random_state=1,
                 file_path='../model/model_saved/GloVe/',
                 model='../model/model_needed/glove.6B.100d.txt'):
        self.dataset = dataset.load_data()
        self.num_clusters = num_clusters
        self.random_state = random_state
        self.model_path = os.path.join(file_path, 'kmeans_model')
        self.df = None
        self.train_df = None
        self.test_df = None
        # Populated by fit() / detection() / load_model(); initialised here so
        # that calling the methods out of order fails with a clear message
        # instead of an AttributeError.
        self.train_vectors = None
        self.test_vectors = None
        self.kmeans_model = None
        self.model = model
        self.embeddings_index = self.load_glove_vectors()

    def load_glove_vectors(self):
        """
        Load GloVe pre-trained word vectors.

        Reads ``self.model`` line by line; the first whitespace-separated
        token of each line is the word and the remaining tokens are its
        coefficients.

        Returns
        -------
        dict
            Mapping from word to a ``float32`` numpy array.
        """
        embeddings_index = {}
        with open(self.model, 'r', encoding='utf8') as f:
            for line in f:
                values = line.split()
                if not values:
                    # Blank line (e.g. trailing newline): nothing to parse.
                    continue
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                embeddings_index[word] = coefs
        return embeddings_index

    def preprocess(self):
        """
        Data preprocessing: tokenization, stop words removal, etc.

        Copies the ``filtered_words`` and ``event_id`` columns of the dataset
        and adds a ``processed_text`` column holding the lower-cased words.
        Rows whose ``filtered_words`` value is not a list get an empty list.

        Returns
        -------
        DataFrame
            The preprocessed frame, also stored in ``self.df``.
        """
        df = self.dataset[['filtered_words', 'event_id']].copy()
        df['processed_text'] = df['filtered_words'].apply(
            lambda x: [str(word).lower() for word in x] if isinstance(x, list) else []
        )
        self.df = df
        return df

    def _embedding_dim(self, default=100):
        """Return the dimensionality of the loaded vectors (``default`` if none)."""
        for vector in self.embeddings_index.values():
            # All vectors of a GloVe file share one length; the first suffices.
            return len(vector)
        return default

    def text_to_glove_vector(self, text, embedding_dim=100):
        """
        Convert text to GloVe vector representation.

        Parameters
        ----------
        text : iterable of str
            The words of one document.
        embedding_dim : int, optional
            Dimensionality of the word vectors. Default: ``100``.

        Returns
        -------
        numpy.ndarray
            Mean of the vectors of the words found in the vocabulary, or a
            zero vector when none of the words is known.
        """
        words = text
        embedding = np.zeros(embedding_dim)
        valid_words = 0
        for word in words:
            if word in self.embeddings_index:
                embedding += self.embeddings_index[word]
                valid_words += 1
        if valid_words > 0:
            embedding /= valid_words
        return embedding

    def create_vectors(self, df, text_column):
        """
        Create GloVe vectors for each document.

        Parameters
        ----------
        df : DataFrame
            Frame holding one document per row.
        text_column : str
            Name of the column containing the word lists.

        Returns
        -------
        numpy.ndarray
            One averaged GloVe vector per row of ``df``.
        """
        texts = df[text_column].tolist()
        # Use the dimensionality of the loaded vectors rather than the
        # hard-coded 100, so GloVe files of any size (50d, 200d, 300d, ...)
        # work with the configurable ``model`` path.
        embedding_dim = self._embedding_dim()
        vectors = np.array([
            self.text_to_glove_vector(text, embedding_dim=embedding_dim)
            for text in texts
        ])
        return vectors

    def load_model(self):
        """
        Load the KMeans model from a file.

        Reads the model that fit() pickled to ``self.model_path``. If no such
        file exists, the model is re-trained on ``self.train_vectors`` as a
        fallback.

        Returns
        -------
        object
            The KMeans model, also stored in ``self.kmeans_model``.

        Raises
        ------
        RuntimeError
            If there is neither a saved model file nor training vectors.
        """
        logging.info(f"Loading KMeans model from {self.model_path}...")
        if os.path.isfile(self.model_path):
            # NOTE: unpickling executes code embedded in the file; model_path
            # must only ever point at a file written by fit(), never at
            # untrusted data.
            with open(self.model_path, 'rb') as f:
                kmeans_model = pickle.load(f)
        else:
            if self.train_vectors is None:
                raise RuntimeError(
                    f"No saved KMeans model at {self.model_path} and no "
                    "training vectors available; call fit() first."
                )
            logging.warning("Saved KMeans model not found; re-training instead.")
            kmeans_model = KMeans(n_clusters=self.num_clusters,
                                  random_state=self.random_state)
            kmeans_model.fit(self.train_vectors)
        logging.info("KMeans model loaded successfully.")
        self.kmeans_model = kmeans_model
        return kmeans_model

    def fit(self):
        """
        Train the KMeans model and save it to ``self.model_path``.

        Splits the preprocessed data 80/20 into train and test frames,
        vectorizes the training documents, fits KMeans on them and pickles
        the fitted model.
        """
        if self.df is None:
            # fit() needs the 'processed_text' column produced by preprocess().
            self.preprocess()

        model_dir = os.path.dirname(self.model_path)
        if model_dir:
            # os.makedirs('') raises, which happens when file_path is empty.
            os.makedirs(model_dir, exist_ok=True)

        train_df, test_df = train_test_split(
            self.df, test_size=0.2, random_state=self.random_state
        )
        self.train_df = train_df
        self.test_df = test_df
        self.train_vectors = self.create_vectors(train_df, 'processed_text')

        logging.info("Training KMeans model...")
        kmeans_model = KMeans(n_clusters=self.num_clusters,
                              random_state=self.random_state)
        kmeans_model.fit(self.train_vectors)
        self.kmeans_model = kmeans_model
        logging.info("KMeans model trained successfully.")

        # Save the trained model to a file
        with open(self.model_path, 'wb') as f:
            pickle.dump(kmeans_model, f)
        logging.info(f"KMeans model saved to {self.model_path}")

    def detection(self):
        """
        Assign clusters to each document.

        Returns
        -------
        tuple of (list, list)
            The ground-truth ``event_id`` values of the test split and the
            predicted cluster label of each test document.

        Raises
        ------
        RuntimeError
            If fit() has not produced a test split yet.
        """
        if self.test_df is None:
            raise RuntimeError("No test split available; call fit() before detection().")

        self.load_model()  # Ensure the model is loaded before making detections
        self.test_vectors = self.create_vectors(self.test_df, 'processed_text')
        labels = self.kmeans_model.predict(self.test_vectors)

        # Get the ground truth labels and predicted labels
        ground_truths = self.test_df['event_id'].tolist()
        predicted_labels = labels.tolist()
        return ground_truths, predicted_labels

    def evaluate(self, ground_truths, predictions):
        """
        Evaluate the model.

        Prints the NMI, AMI and ARI scores of ``predictions`` against
        ``ground_truths``.
        """
        # Calculate Normalized Mutual Information (NMI)
        nmi = metrics.normalized_mutual_info_score(ground_truths, predictions)
        print(f"Normalized Mutual Information (NMI): {nmi}")

        # Calculate Adjusted Mutual Information (AMI)
        ami = metrics.adjusted_mutual_info_score(ground_truths, predictions)
        print(f"Adjusted Mutual Information (AMI): {ami}")

        # Calculate Adjusted Rand Index (ARI)
        ari = metrics.adjusted_rand_score(ground_truths, predictions)
        print(f"Adjusted Rand Index (ARI): {ari}")