# TSA/classification.py
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
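
# Groups articles under the closest central topic: each sentence is embedded
# (together with its two neighbours on either side) with multilingual
# DistilBERT, mean-pooled, and matched to a topic by cosine similarity.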
def classify_by_topic(articles, central_topics):
    # Compute the similarity of each article to every central topic,
    # returning a (num_articles, num_topics) matrix.
    def compute_similarity(articles, central_topics):
        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
        model.eval()

        def sentence_to_vector(sentence, context):
            # Up-weight the target sentence by repeating it four times,
            # flanked by the two sentences before and after it.
            sentence = context[0] + context[1] + sentence * 4 + context[2] + context[3]
            tokens = tokenizer(
                sentence, add_special_tokens=True, return_tensors="pt",
                max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state
            # Mean-pool the token embeddings into one fixed-size vector.
            vector = np.squeeze(torch.mean(hidden_states, dim=1).numpy())
            return vector
        # Gather a sentence's context: the two sentences before it and the
        # two after it, substituting empty strings at the boundaries.
        def get_context(sentences, index):
            if index == 0:
                prev_sentence = ""
                pprev_sentence = ""
            elif index == 1:
                prev_sentence = sentences[index - 1]
                pprev_sentence = ""
            else:
                prev_sentence = sentences[index - 1]
                pprev_sentence = sentences[index - 2]

            if index == len(sentences) - 1:
                next_sentence = ""
                nnext_sentence = ""
            elif index == len(sentences) - 2:
                next_sentence = sentences[index + 1]
                nnext_sentence = ""
            else:
                next_sentence = sentences[index + 1]
                nnext_sentence = sentences[index + 2]

            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)
        doc_vectors = [sentence_to_vector(sentence, get_context(articles, i))
                       for i, sentence in enumerate(articles)]
        # Topics are embedded the same way, using neighbouring topics as context.
        topic_vectors = [sentence_to_vector(sentence, get_context(central_topics, i))
                         for i, sentence in enumerate(central_topics)]

        # Cosine-similarity matrix: rows are articles, columns are topics.
        cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)
        return cos_sim_matrix
    # Assign each article to the central topic it is most similar to.
    def group_by_topic(articles, central_topics, similarity_matrix):
        group = []
        for article, similarity in zip(articles, similarity_matrix):
            max_index = int(np.argmax(similarity))
            group.append((article, central_topics[max_index]))
        return group
    # Run the pipeline: embed, compare, then group.
    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)
    return groups
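

# Minimal usage sketch (the sentences and topic labels below are illustrative
# placeholders, not data from this project). classify_by_topic takes two lists
# of strings and returns (article, matched_topic) pairs.
if __name__ == "__main__":
    articles = [
        "The central bank raised interest rates again this quarter.",
        "The new striker scored twice in the season opener.",
    ]
    central_topics = ["economy and finance", "sports"]
    for article, topic in classify_by_topic(articles, central_topics):
        print(f"{topic}: {article}")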