Spaces:
Build error
Build error
import numpy | |
import torch | |
import gradio as gr | |
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline | |
from article_extractor.tokenizers_pegasus import PegasusTokenizer | |
from embed import Embed | |
import tensorflow as tf | |
from harvesttext import HarvestText | |
from sentence_transformers import SentenceTransformer, util | |
from LexRank import degree_centrality_scores | |
class SummaryExtractor(object): | |
def __init__(self): | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.model = PegasusForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese').to(self.device) | |
self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese") | |
self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device) | |
def extract(self, content: str) -> str: | |
print(content) | |
return str(self.text2text_genr(content, do_sample=False, num_return_sequences=3)[0]["generated_text"]) | |
class LexRank(object): | |
def __init__(self): | |
self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2') | |
self.ht = HarvestText() | |
def find_central(self, content: str): | |
sentences = self.ht.cut_sentences(content) | |
embeddings = self.model.encode(sentences, convert_to_tensor=True) | |
# Compute the pair-wise cosine similarities | |
cos_scores = util.cos_sim(embeddings, embeddings).numpy() | |
# Compute the centrality for each sentence | |
centrality_scores = degree_centrality_scores(cos_scores, threshold=None) | |
# We argsort so that the first element is the sentence with the highest score | |
most_central_sentence_indices = numpy.argsort(-centrality_scores) | |
return most_central_sentence_indices | |
# ---===--- worker instances ---===--- | |
t_randeng = SummaryExtractor() | |
embedder = Embed() | |
lex = LexRank() | |
def randeng_extract(content): | |
sentences = lex.find_central(content) | |
num = 500 | |
ptr = 0 | |
for index, sentence in enumerate(sentences): | |
num -= len(sentence) | |
if num < 0 and index > 0: | |
ptr = index - 1 | |
break | |
if num < 0 and index == 0: | |
ptr = index | |
break | |
print(">>>") | |
for ele in sentences[:ptr]: | |
print(ele) | |
return t_randeng.extract("".join(sentences[:ptr])) | |
def similarity_check(inputs: list): | |
doc_list = inputs[1].split("\n") | |
doc_list.append(inputs[0]) | |
embedding_list = embedder.encode(doc_list) | |
scores = (embedding_list[-1] @ tf.transpose(embedding_list[:-1]))[0].numpy().tolist() | |
return numpy.array2string(scores, separator=',') | |
with gr.Blocks() as app: | |
gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]") | |
# with gr.Tab("CamelBell-Chinese-LoRA"): | |
# text_input = gr.Textbox() | |
# text_output = gr.Textbox() | |
# text_button = gr.Button("生成摘要") | |
with gr.Tab("Randeng-Pegasus-523M"): | |
text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000) | |
text_output_1 = gr.Textbox(label="摘要文本") | |
text_button_1 = gr.Button("生成摘要") | |
with gr.Tab("相似度检测"): | |
with gr.Row(): | |
text_input_query = gr.Textbox(label="查询文本") | |
text_input_doc = gr.Textbox(lines=10, label="逐行输入待比较的文本列表") | |
text_button_similarity = gr.Button("对比相似度") | |
text_output_similarity = gr.Textbox() | |
# text_button.click(tuoling_extract, inputs=text_input, outputs=text_output) | |
text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1) | |
text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity) | |
app.launch( | |
# share=True, | |
# debug=True | |
) | |