Spaces:
Build error
Build error
import torch | |
import gradio as gr | |
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline | |
from article_extractor.tokenizers_pegasus import PegasusTokenizer | |
class SummaryExtractor:
    """Abstractive summarizer built on a Pegasus text2text-generation pipeline.

    The model and tokenizer are both loaded from ``MODEL_NAME`` and placed on
    CUDA when ``torch.cuda.is_available()`` is true, otherwise on the CPU.
    """

    # Hugging Face Hub id shared by the model weights and the tokenizer.
    MODEL_NAME = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese"

    def __init__(self):
        """Load the model and tokenizer and assemble the generation pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = PegasusForConditionalGeneration.from_pretrained(
            self.MODEL_NAME
        ).to(self.device)
        self.tokenizer = PegasusTokenizer.from_pretrained(self.MODEL_NAME)
        self.text2text_genr = Text2TextGenerationPipeline(
            self.model, self.tokenizer, device=self.device
        )

    def extract(self, content: str, min=20, max=30) -> str:
        """Return a summary of ``content``.

        Args:
            content: Source text to summarize.
            min: Forwarded to the pipeline as ``min_length``.
            max: Forwarded to the pipeline as ``max_length``.

        Returns:
            The ``generated_text`` of the first pipeline result, or an empty
            string when ``content`` is empty or only whitespace.
        """
        # Nothing to summarize: skip the model call instead of generating
        # text from an empty prompt.
        if not content or not content.strip():
            return ""
        # Only the first candidate is ever used, so request exactly one.
        # Asking for several sequences without sampling is rejected by
        # greedy decoding and is wasted work under beam search.
        results = self.text2text_genr(
            content,
            do_sample=False,
            min_length=min,
            max_length=max,
            num_return_sequences=1,
        )
        return str(results[0]["generated_text"])
# Single shared extractor, built once at import time so the model is not
# reloaded for every request.
t_randeng = SummaryExtractor()


def randeng_extract(content):
    """Summarize ``content`` with the shared ``t_randeng`` extractor.

    Uses the default length limits of ``SummaryExtractor.extract``.
    """
    summary = t_randeng.extract(content)
    return summary
def similarity_check(query: str, doc: str):
    """Stub for comparing ``query`` against the lines of ``doc``.

    No scoring is implemented yet: ``doc`` is split on newlines, the pieces
    are not used, and a fixed placeholder string is returned.
    """
    # One candidate per line of ``doc``; not consumed by anything yet.
    candidate_lines = doc.split("\n")
    placeholder = "similarity result"
    return placeholder
# Gradio UI: one tab per summarization model plus a similarity-check tab.
with gr.Blocks() as app:
    gr.Markdown("从下面的标签选择不同的摘要模型, 在左侧输入原文")

    # NOTE: a "CamelBell-Chinese-LoRA" tab wired to `tuoling_extract` was
    # sketched here earlier but is disabled; no such handler is visible in
    # this file.

    with gr.Tab("Randeng-Pegasus-523M"):
        text_input_1 = gr.Textbox()
        text_output_1 = gr.Textbox()
        text_button_1 = gr.Button("生成摘要")

    with gr.Tab("相似度检测"):
        # Query and candidate document side by side; the button and the
        # result box sit below them.
        with gr.Row():
            text_input_query = gr.Textbox()
            text_input_doc = gr.Textbox()
        text_button_similarity = gr.Button("对比相似度")
        text_output_similarity = gr.Textbox()

    text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
    # similarity_check takes (query, doc), so both textboxes are inputs, and
    # the result goes to the dedicated output box rather than overwriting
    # the document input.
    text_button_similarity.click(
        similarity_check,
        inputs=[text_input_query, text_input_doc],
        outputs=text_output_similarity,
    )

app.launch()