File size: 3,965 Bytes
93e5f33
e24946b
93e5f33
19d5657
 
e7699c1
e0738a2
 
93e5f33
77129d5
e7699c1
19d5657
 
1870c14
 
19d5657
93e5f33
e24946b
4d46ad9
7090141
bbf9596
 
 
e0738a2
1870c14
19d5657
 
3611e07
4d46ad9
19d5657
 
 
 
 
 
 
 
 
4d46ad9
3611e07
 
 
 
 
 
93e5f33
77129d5
 
 
 
 
 
 
 
 
 
 
 
 
 
7090141
77129d5
7090141
77129d5
 
 
93e5f33
1870c14
90f83ff
3611e07
77129d5
1870c14
7b20eab
3611e07
 
 
 
19d5657
 
 
 
1870c14
 
7090141
 
1870c14
 
 
93e5f33
19d5657
 
77129d5
1870c14
e7699c1
93e5f33
5170bd0
93e5f33
e9377d8
 
e7699c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import math, torch, gradio as gr

from lex_rank import LexRank
from lex_rank_distiluse_v1 import LexRankDistiluseV1
from lex_rank_L12 import LexRankL12
from sentence_transformers import SentenceTransformer, util


# ---===--- instances ---===---
# Shared objects created once at import time and reused by every request handler.
# Sentence encoder used by similarity_search for both corpus and query text.
embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
# Summarizers, one per handler: extract_handler, extract_handler_distiluse_v1
# and extract_handler_l12 respectively.
lex = LexRank()
lex_distiluse_v1 = LexRankDistiluseV1()
lex_l12 = LexRankL12()


# Summarization method 1
def extract_handler(content):
    """Return the sentences picked by ``lex.find_central`` as a numbered list.

    Args:
        content: Text to summarize; passed unchanged to ``lex.find_central``.

    Returns:
        A single string with one line per returned sentence, each formatted as
        ``<index>: <sentence>`` followed by a newline, where the index is
        zero-based. Empty when no sentences are returned.
    """
    # NOTE(review): an earlier version computed ceil(len(content) / 10) as a
    # target summary length but never used it, so the dead local was removed.
    # If find_central accepts a length/count argument it should presumably be
    # passed here -- confirm against the LexRank class.
    sentences = lex.find_central(content)
    return "".join(
        f"{index}: {sentence}\n" for index, sentence in enumerate(sentences)
    )


# Summarization method 2
def extract_handler_distiluse_v1(content):
    """Return the sentences picked by ``lex_distiluse_v1.find_central`` as a numbered list.

    Args:
        content: Text to summarize; passed unchanged to
            ``lex_distiluse_v1.find_central``.

    Returns:
        A single string with one line per returned sentence, each formatted as
        ``<index>: <sentence>`` followed by a newline, where the index is
        zero-based. Empty when no sentences are returned.
    """
    # NOTE(review): an earlier version computed ceil(len(content) / 10) as a
    # target summary length but never used it, so the dead local was removed.
    # If find_central accepts a length/count argument it should presumably be
    # passed here -- confirm against the LexRankDistiluseV1 class.
    sentences = lex_distiluse_v1.find_central(content)
    return "".join(
        f"{index}: {sentence}\n" for index, sentence in enumerate(sentences)
    )


# Summarization method 3
def extract_handler_l12(content):
    """Return the sentences picked by ``lex_l12.find_central`` as a numbered list.

    Args:
        content: Text to summarize; passed unchanged to
            ``lex_l12.find_central``.

    Returns:
        A single string with one line per returned sentence, each formatted as
        ``<index>: <sentence>`` followed by a newline, where the index is
        zero-based. Empty when no sentences are returned.
    """
    # NOTE(review): an earlier version computed ceil(len(content) / 10) as a
    # target summary length but never used it, so the dead local was removed.
    # If find_central accepts a length/count argument it should presumably be
    # passed here -- confirm against the LexRankL12 class.
    sentences = lex_l12.find_central(content)
    return "".join(
        f"{index}: {sentence}\n" for index, sentence in enumerate(sentences)
    )


# Similarity search
def similarity_search(queries, doc, top_k=5):
    """Rank the lines of ``doc`` by cosine similarity to each line of ``queries``.

    Args:
        queries: Newline-separated query sentences. Blank lines are ignored.
        doc: Newline-separated candidate sentences. Blank lines are ignored.
        top_k: Maximum number of matches reported per query. It is capped at
            the number of non-blank candidate lines.

    Returns:
        A text report with one section per query: a separator, the query, a
        header naming how many matches follow, then one line per match in
        descending score order formatted as ``<sentence>(Score: <score>)``.
        Returns an empty string when either input has no non-blank lines.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Drop blank lines (and stray "\r" from pasted Windows text); otherwise an
    # empty string would be embedded and ranked like a real sentence.
    doc_list = [line.strip() for line in (doc or "").split('\n') if line.strip()]
    query_list = [line.strip() for line in (queries or "").split('\n') if line.strip()]
    if not doc_list or not query_list:
        return ""

    k = min(top_k, len(doc_list))
    corpus_embeddings = embedder.encode(doc_list, convert_to_tensor=True)
    # Encode every query in one batch instead of one model call per query.
    query_embeddings = embedder.encode(query_list, convert_to_tensor=True)
    # Row i holds the cosine similarity of query i against every corpus line.
    all_scores = util.cos_sim(query_embeddings, corpus_embeddings)

    parts = []
    for query, cos_scores in zip(query_list, all_scores):
        top_results = torch.topk(cos_scores, k=k)
        parts.append("\n\n======================\n\n")
        parts.append(f"Query: {query}")
        parts.append(f"\nTop {k} most similar sentences in corpus:\n")
        # Convert to plain Python numbers so scores print as floats rather
        # than as "tensor(...)" and indices are ordinary ints.
        scores = top_results.values.tolist()
        indices = top_results.indices.tolist()
        for score, idx in zip(scores, indices):
            parts.append(f"{doc_list[idx]}(Score: {score:.4f})\n")
    return "".join(parts)


# Web UI: three summarization tabs (one per summarizer) plus a similarity tab.
# NOTE(review): the summary output labels state that the summary length is
# 1/10 of the input length, but the handlers wired below call
# find_central(content) without any length argument -- confirm the label
# matches the actual behavior.
with gr.Blocks() as app:
    gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
    # Tab 1: long-text input, trigger button and summary output for extract_handler.
    with gr.Tab("LexRank-mpnet"):
        text_input_1 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
        text_button_1 = gr.Button("生成摘要")
        text_output_1 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
    # Tab 2: same layout, wired to extract_handler_distiluse_v1.
    with gr.Tab("LexRank-distiluse"):
        text_input_2 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
        text_button_2 = gr.Button("生成摘要")
        text_output_2 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
    # Tab 3: same layout, wired to extract_handler_l12.
    with gr.Tab("LexRank-MiniLM-L12-v2"):
        text_input_3 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
        text_button_3 = gr.Button("生成摘要")
        text_output_3 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
    # Tab 4: similarity search; queries and candidate sentences are each
    # entered one per line (similarity_search splits both on newlines).
    with gr.Tab("相似度检测"):
        with gr.Row():
            text_input_query = gr.Textbox(lines=10, label="查询文本")
            text_input_doc = gr.Textbox(lines=20, label="逐行输入待比较的文本列表")
        text_button_similarity = gr.Button("对比相似度")
        text_output_similarity = gr.Textbox()

    # Event wiring: each button passes its input box(es) to a handler and
    # writes the returned string into the matching output box.
    text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
    text_button_2.click(extract_handler_distiluse_v1, inputs=text_input_2, outputs=text_output_2)
    text_button_3.click(extract_handler_l12, inputs=text_input_3, outputs=text_output_3)
    text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)

# Start the Gradio server with default settings. The options below are kept
# as commented-out examples and are currently disabled.
app.launch(
    # Enabling share will generate a temporary public link.
    # share=True,
    # debug=True,
    # auth=("qee", "world"),
    # auth_message="请登陆"  (login prompt text; means "Please log in")
           )