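"""Gradio demo with two tabs: Chinese abstractive summarization (LexRank sentence
selection followed by Randeng-Pegasus-523M) and embedding-based similarity checking
of a query against a list of texts."""
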
import numpy
import torch
import gradio as gr

from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
from article_extractor.tokenizers_pegasus import PegasusTokenizer
from embed import Embed

import tensorflow as tf

from harvesttext import HarvestText
from sentence_transformers import SentenceTransformer, util
from LexRank import degree_centrality_scores


class SummaryExtractor(object):
    """Abstractive summarizer wrapping IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = PegasusForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese').to(self.device)
        self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
        self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)

    def extract(self, content: str) -> str:
        # Greedy decoding (do_sample=False, no beam search) can only return one sequence.
        return str(self.text2text_genr(content, do_sample=False, num_return_sequences=1)[0]["generated_text"])

class LexRank(object):
    """Extractive ranking: embed each sentence with a multilingual Sentence-BERT model
    and rank the sentences by LexRank degree centrality."""

    def __init__(self):
        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.ht = HarvestText()

    def find_central(self, content: str) -> list:
        sentences = self.ht.cut_sentences(content)
        embeddings = self.model.encode(sentences, convert_to_tensor=True)

        # Compute the pair-wise cosine similarities
        cos_scores = util.cos_sim(embeddings, embeddings).cpu().numpy()

        # Compute the centrality for each sentence
        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

        # Sort so that the most central sentence comes first, and return the
        # sentences themselves rather than their indices (the caller joins them).
        most_central_sentence_indices = numpy.argsort(-centrality_scores)
        return [sentences[idx] for idx in most_central_sentence_indices]

# ---===--- worker instances ---===---
t_randeng = SummaryExtractor()
embedder = Embed()
lex = LexRank()


def randeng_extract(content):
    # Rank sentences by centrality, then keep the most central ones until a
    # ~500-character budget for the abstractive model is exhausted.
    sentences = lex.find_central(content)
    budget = 500
    selected = []
    for sentence in sentences:
        budget -= len(sentence)
        if budget < 0 and selected:
            break
        selected.append(sentence)
    return t_randeng.extract("".join(selected))


def similarity_check(query: str, docs: str) -> str:
    # Gradio passes each input component as its own positional argument.
    doc_list = docs.split("\n")
    doc_list.append(query)
    # Assumes Embed.encode returns a 2-D (num_texts, dim) tensor.
    embedding_list = embedder.encode(doc_list)
    # Dot-product scores of the query embedding (last row) against every candidate text.
    scores = (embedding_list[-1:] @ tf.transpose(embedding_list[:-1]))[0].numpy()
    return numpy.array2string(scores, separator=',')

with gr.Blocks() as app:
    gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")  # "Select a test module from the tabs below [summary generation, similarity check]"
    # with gr.Tab("CamelBell-Chinese-LoRA"):
    #     text_input = gr.Textbox()
    #     text_output = gr.Textbox()
    #     text_button = gr.Button("生成摘要")
    with gr.Tab("Randeng-Pegasus-523M"):
        text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
        text_output_1 = gr.Textbox(label="摘要文本")
        text_button_1 = gr.Button("生成摘要")
    with gr.Tab("相似度检测"):
        with gr.Row():
            text_input_query = gr.Textbox(label="查询文本")
            text_input_doc = gr.Textbox(lines=10, label="逐行输入待比较的文本列表")
        text_button_similarity = gr.Button("对比相似度")
        text_output_similarity = gr.Textbox()

    # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
    text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
    text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)

app.launch(
    # share=True,
    # debug=True
)
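# Note: without share=True, Gradio serves the UI locally only (by default at
# http://127.0.0.1:7860); uncommenting share=True requests a temporary public link.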