hellopahe committed on
Commit
3611e07
1 Parent(s): e9377d8
Files changed (2) hide show
  1. app.py +18 -1
  2. lex_rank_new_model.py +44 -0
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import math, torch, gradio as gr
2
 
3
  from lex_rank import LexRank
 
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
 
7
  # ---===--- instances ---===---
8
  embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
9
  lex = LexRank()
 
10
 
11
 
12
  # 摘要方法
@@ -19,6 +21,16 @@ def extract_handler(content):
19
  return output
20
 
21
 
 
 
 
 
 
 
 
 
 
 
22
  # 相似度检测方法
23
  def similarity_search(queries, doc):
24
  doc_list = doc.split('\n')
@@ -43,10 +55,14 @@ def similarity_search(queries, doc):
43
  # web ui
44
  with gr.Blocks() as app:
45
  gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
46
- with gr.Tab("LexRank"):
47
  text_input_1 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
48
  text_button_1 = gr.Button("生成摘要")
49
  text_output_1 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
 
 
 
 
50
  with gr.Tab("相似度检测"):
51
  with gr.Row():
52
  text_input_query = gr.Textbox(lines=10, label="查询文本")
@@ -55,6 +71,7 @@ with gr.Blocks() as app:
55
  text_output_similarity = gr.Textbox()
56
 
57
  text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
 
58
  text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
59
 
60
  app.launch(
 
1
  import math, torch, gradio as gr
2
 
3
  from lex_rank import LexRank
4
+ from lex_rank_new_model import LexRankNewModel
5
  from sentence_transformers import SentenceTransformer, util
6
 
7
 
8
  # ---===--- instances ---===---
9
  embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
10
  lex = LexRank()
11
+ lex_new_model = LexRankNewModel()
12
 
13
 
14
  # 摘要方法
 
21
  return output
22
 
23
 
24
# Summarization handler for the distiluse-based LexRank model.
def extract_handler_new_model(content):
    """Summarize *content* to roughly 1/10 of its character length.

    Delegates sentence selection to the module-level ``lex_new_model``
    instance and returns the chosen sentences as an indexed,
    newline-separated string.
    """
    summary_length = math.ceil(len(content) / 10)
    sentences = lex_new_model.find_central(content, num=summary_length)
    return "".join(
        f"{index}: {sentence}\n" for index, sentence in enumerate(sentences)
    )
32
+
33
+
34
  # 相似度检测方法
35
  def similarity_search(queries, doc):
36
  doc_list = doc.split('\n')
 
55
  # web ui
56
  with gr.Blocks() as app:
57
  gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
58
+ with gr.Tab("LexRank-mpnet"):
59
  text_input_1 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
60
  text_button_1 = gr.Button("生成摘要")
61
  text_output_1 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
62
+ with gr.Tab("LexRank-distiluse"):
63
+ text_input_2 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
64
+ text_button_2 = gr.Button("生成摘要")
65
+ text_output_2 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
66
  with gr.Tab("相似度检测"):
67
  with gr.Row():
68
  text_input_query = gr.Textbox(lines=10, label="查询文本")
 
71
  text_output_similarity = gr.Textbox()
72
 
73
  text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
74
+ text_button_2.click(extract_handler_new_model, inputs=text_input_2, outputs=text_output_2)
75
  text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
76
 
77
  app.launch(
lex_rank_new_model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy, nltk
2
+ nltk.download('punkt')
3
+
4
+
5
+ from harvesttext import HarvestText
6
+ from lex_rank_util import degree_centrality_scores
7
+ from sentence_transformers import SentenceTransformer, util
8
+
9
+
10
class LexRankNewModel(object):
    """LexRank extractive summarizer backed by the distiluse sentence encoder."""

    def __init__(self):
        # Multilingual encoder used to embed sentences for similarity scoring.
        self.model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
        # HarvestText provides sentence segmentation for Chinese text.
        self.ht = HarvestText()

    def find_central(self, content: str, num=100):
        """Return the most central sentences of *content*.

        ``num`` acts as a character budget: sentences, taken in order of
        decreasing centrality, are collected until the budget is spent.
        The sentence that overflows the budget is still included.
        """
        # Choose a segmenter depending on whether the text contains Chinese.
        if self.contains_chinese(content):
            sentences = self.ht.cut_sentences(content)
        else:
            sentences = nltk.sent_tokenize(content)

        # Embed every sentence, then score pair-wise cosine similarity.
        embeddings = self.model.encode(sentences, convert_to_tensor=True).cpu()
        similarity = util.cos_sim(embeddings, embeddings).numpy()

        # Degree centrality ranks sentences; argsort on the negated scores
        # puts the highest-scoring sentence first.
        scores = degree_centrality_scores(similarity, threshold=None)
        ranking = numpy.argsort(-scores)

        selected, budget = [], num
        for idx in ranking:
            if budget < 0:
                break
            sentence = sentences[idx]
            selected.append(sentence)
            budget -= len(sentence)
        return selected

    def contains_chinese(self, content: str):
        """Return True if *content* has at least one character in U+4E00–U+9FA5."""
        return any('\u4e00' <= ch <= '\u9fa5' for ch in content)