import nltk
import numpy
from harvesttext import HarvestText
from lex_rank_util import degree_centrality_scores
from sentence_transformers import SentenceTransformer, util

nltk.download('punkt')
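
# Note: `degree_centrality_scores` lives in a local lex_rank_util module that is
# not shown here. As a rough, hypothetical sketch of the idea (the real helper
# may differ, e.g. by using power iteration): with threshold=None, a sentence's
# degree centrality is its summed similarity to every sentence in the document.
def degree_centrality_scores_sketch(similarity_matrix, threshold=None):
    if threshold is not None:
        # With a threshold, edges are binarized before summing
        similarity_matrix = numpy.where(similarity_matrix >= threshold, 1.0, 0.0)
    return similarity_matrix.sum(axis=1)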


class LexRank:
    def __init__(self):
        # Multilingual model, so both Chinese and English input can be embedded
        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.ht = HarvestText()

    def find_central(self, content: str, num=10):
        # HarvestText handles Chinese sentence splitting; NLTK handles the rest
        if self.contains_chinese(content):
            sentences = self.ht.cut_sentences(content)
        else:
            sentences = nltk.sent_tokenize(content)
        embeddings = self.model.encode(sentences, convert_to_tensor=True).cpu()

        # Compute the pair-wise cosine similarities
        cos_scores = util.cos_sim(embeddings, embeddings).numpy()

        # Compute the centrality for each sentence
        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

        # Argsort so that the first element is the sentence with the highest score
        most_central_sentence_indices = numpy.argsort(-centrality_scores)

        # Keep the `num` most central sentences
        res = []
        for index in most_central_sentence_indices:
            if num <= 0:
                break
            res.append(sentences[index])
            num -= 1
        return res

    def contains_chinese(self, content: str):
        # CJK Unified Ideographs range, used as a heuristic for Chinese text
        for _char in content:
            if '\u4e00' <= _char <= '\u9fa5':
                return True
        return False
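

# Usage sketch (the sample text below is illustrative, not from the original app):
if __name__ == '__main__':
    lex_rank = LexRank()
    text = (
        "LexRank is an unsupervised graph-based method for extractive summarization. "
        "It scores each sentence by its centrality in a sentence-similarity graph. "
        "The most central sentences are returned as the summary."
    )
    # find_central returns the `num` most central sentences, most central first
    for sentence in lex_rank.find_central(text, num=2):
        print(sentence)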