# hellopahe · commit e7699c1: "add lexrank"
import numpy
import torch
import gradio as gr
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
from article_extractor.tokenizers_pegasus import PegasusTokenizer
from embed import Embed
import tensorflow as tf
from harvesttext import HarvestText
from sentence_transformers import SentenceTransformer, util
from LexRank import degree_centrality_scores
class SummaryExtractor(object):
    """Abstractive summarizer built on the Randeng-Pegasus-523M Chinese model.

    The model and tokenizer are loaded once in ``__init__`` and wrapped in a
    ``Text2TextGenerationPipeline`` that ``extract`` calls per request.
    """

    _MODEL_NAME = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese"

    def __init__(self):
        # Prefer the GPU when torch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = PegasusForConditionalGeneration.from_pretrained(self._MODEL_NAME).to(self.device)
        self.tokenizer = PegasusTokenizer.from_pretrained(self._MODEL_NAME)
        self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)

    def extract(self, content: str) -> str:
        """Return the generated summary for *content*.

        Args:
            content: Text to summarize.

        Returns:
            The ``generated_text`` of the first pipeline output, as ``str``.
        """
        print(content)
        # Only the first candidate is used, so ask for exactly one sequence.
        # Greedy decoding (do_sample=False with a single beam) does not
        # accept num_return_sequences > 1. Under beam search, the first
        # returned sequence is the best beam either way.
        outputs = self.text2text_genr(content, do_sample=False, num_return_sequences=1)
        return str(outputs[0]["generated_text"])
class LexRank(object):
    """Rank the sentences of a text by LexRank degree centrality.

    Sentences are split with HarvestText. They are embedded with the
    'paraphrase-multilingual-mpnet-base-v2' SentenceTransformer and scored
    with ``degree_centrality_scores`` on their pair-wise cosine similarities.
    """

    def __init__(self):
        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.ht = HarvestText()

    def find_central(self, content: str):
        """Return sentence indices of *content*, most central first.

        The indices refer to positions in ``self.ht.cut_sentences(content)``.

        Args:
            content: Text to rank.

        Returns:
            A numpy integer array of indices sorted by descending centrality.
            It is empty when *content* yields no sentences.
        """
        sentences = self.ht.cut_sentences(content)
        if not sentences:
            # Nothing to embed or rank; avoid feeding an empty batch downstream.
            return numpy.array([], dtype=numpy.intp)
        embeddings = self.model.encode(sentences, convert_to_tensor=True)
        # Pair-wise cosine similarities. The torch tensor may live on the GPU,
        # so move it to the CPU before converting: .numpy() rejects CUDA tensors.
        cos_scores = util.cos_sim(embeddings, embeddings).cpu().numpy()
        # Degree centrality of each sentence in the similarity graph.
        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
        # Negate so argsort yields descending order: index 0 is the most central sentence.
        return numpy.argsort(-centrality_scores)
# ---===--- worker instances ---===---
# Built once at import time, left to right, and shared by the callbacks
# randeng_extract (t_randeng, lex) and similarity_check (embedder).
t_randeng, embedder, lex = SummaryExtractor(), Embed(), LexRank()
def randeng_extract(content: str, max_chars: int = 500) -> str:
    """Summarize *content* from its most central sentences.

    Sentences are ranked with ``lex.find_central``. The top-ranked ones are
    taken in rank order until the next one would push the total past
    *max_chars* characters. The first (most central) sentence is always kept,
    so non-empty text never produces empty summarizer input.

    Args:
        content: Long input text.
        max_chars: Character budget for the text passed to the summarizer.

    Returns:
        The summary produced by ``t_randeng.extract``, or ``""`` when
        *content* contains no sentences.
    """
    # find_central returns indices into this same sentence split, so repeat
    # the split to map the ranking back to sentence text.
    sentences = lex.ht.cut_sentences(content)
    if not sentences:
        return ""
    ranked = [sentences[i] for i in lex.find_central(content)]

    selected = []
    budget = max_chars
    for sentence in ranked:
        # Stop at the first sentence that does not fit, but always keep the
        # top-ranked one even if it alone exceeds the budget.
        if selected and len(sentence) > budget:
            break
        selected.append(sentence)
        budget -= len(sentence)

    print(">>>")
    for ele in selected:
        print(ele)
    return t_randeng.extract("".join(selected))
def similarity_check(inputs, docs=None):
    """Score each candidate line against a query via embedding dot products.

    Gradio calls this with one positional argument per input component,
    i.e. ``similarity_check(query, docs)``. For backward compatibility, a
    single ``[query, docs]`` sequence is also accepted.

    Args:
        inputs: The query text, or a ``[query, docs]`` pair when *docs* is None.
        docs: Newline-separated candidate texts. Blank lines are ignored.

    Returns:
        The scores, one per non-blank candidate line in input order, formatted
        by ``numpy.array2string`` with ``','`` separators. Returns ``"[]"``
        when there are no candidates.
    """
    if docs is None:
        query, docs = inputs[0], inputs[1]
    else:
        query = inputs
    doc_list = [line for line in docs.splitlines() if line.strip()]
    if not doc_list:
        return "[]"
    # Embed candidates and query in one batch; the query is the last row.
    embedding_list = embedder.encode(doc_list + [query])
    # Slice the query as a (1, dim) row so the matmul is rank-2 x rank-2, then
    # take row 0 to get one score per candidate.
    # NOTE(review): assumes embedder.encode returns a (num_texts, dim) TF-compatible tensor — confirm against embed.Embed.
    scores = (embedding_list[-1:] @ tf.transpose(embedding_list[:-1]))[0].numpy()
    # array2string needs an ndarray, so do not convert to a list first.
    return numpy.array2string(scores, separator=',')
with gr.Blocks() as app:
    # UI text: "Choose a test module from the tabs below [summary generation, similarity detection]".
    gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
    # Summarization tab: long text in, summary out (handled by randeng_extract).
    with gr.Tab("Randeng-Pegasus-523M"):
        text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
        text_output_1 = gr.Textbox(label="摘要文本")
        text_button_1 = gr.Button("生成摘要")
    # Similarity tab: a query plus one candidate text per line (handled by similarity_check).
    with gr.Tab("相似度检测"):
        with gr.Row():
            text_input_query = gr.Textbox(label="查询文本")
            text_input_doc = gr.Textbox(lines=10, label="逐行输入待比较的文本列表")
        text_button_similarity = gr.Button("对比相似度")
        text_output_similarity = gr.Textbox()
    text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
    text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)

if __name__ == "__main__":
    # Launch only when run as a script so importing this module does not start a server.
    # Pass share=True / debug=True here when needed.
    app.launch()