hellopahe commited on
Commit
19d5657
1 Parent(s): 657b2a2
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import math, torch, gradio as gr
2
 
3
  from lex_rank import LexRank
4
- from lex_rank_new_model import LexRankNewModel
 
5
  from sentence_transformers import SentenceTransformer, util
6
 
7
 
8
  # ---===--- instances ---===---
9
  embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
10
  lex = LexRank()
11
- lex_new_model = LexRankNewModel()
 
12
 
13
 
14
- # 摘要方法
15
  def extract_handler(content):
16
  summary_length = math.ceil(len(content) / 10)
17
  sentences = lex.find_central(content, num=summary_length)
@@ -21,10 +23,20 @@ def extract_handler(content):
21
  return output
22
 
23
 
24
- # 摘要方法
25
- def extract_handler_new_model(content):
26
  summary_length = math.ceil(len(content) / 10)
27
- sentences = lex_new_model.find_central(content, num=summary_length)
 
 
 
 
 
 
 
 
 
 
28
  output = ""
29
  for index, sentence in enumerate(sentences):
30
  output += f"{index}: {sentence}\n"
@@ -63,6 +75,10 @@ with gr.Blocks() as app:
63
  text_input_2 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
64
  text_button_2 = gr.Button("生成摘要")
65
  text_output_2 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
 
 
 
 
66
  with gr.Tab("相似度检测"):
67
  with gr.Row():
68
  text_input_query = gr.Textbox(lines=10, label="查询文本")
@@ -71,7 +87,8 @@ with gr.Blocks() as app:
71
  text_output_similarity = gr.Textbox()
72
 
73
  text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
74
- text_button_2.click(extract_handler_new_model, inputs=text_input_2, outputs=text_output_2)
 
75
  text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
76
 
77
  app.launch(
 
1
  import math, torch, gradio as gr
2
 
3
  from lex_rank import LexRank
4
+ from lex_rank_distiluse_v1 import LexRankDistiluseV1
5
+ from lex_rank_L12 import LexRankL12
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
 
9
  # ---===--- instances ---===---
10
  embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
11
  lex = LexRank()
12
+ lex_distiluse_v1 = LexRankDistiluseV1()
13
+ lex_l12 = LexRankL12()
14
 
15
 
16
+ # 摘要方法1
17
  def extract_handler(content):
18
  summary_length = math.ceil(len(content) / 10)
19
  sentences = lex.find_central(content, num=summary_length)
 
23
  return output
24
 
25
 
26
+ # 摘要方法2
27
+ def extract_handler_distiluse_v1(content):
28
  summary_length = math.ceil(len(content) / 10)
29
+ sentences = lex_distiluse_v1.find_central(content, num=summary_length)
30
+ output = ""
31
+ for index, sentence in enumerate(sentences):
32
+ output += f"{index}: {sentence}\n"
33
+ return output
34
+
35
+
36
+ # 摘要方法3
37
+ def extract_handler_l12(content):
38
+ summary_length = math.ceil(len(content) / 10)
39
+ sentences = lex_l12.find_central(content, num=summary_length)
40
  output = ""
41
  for index, sentence in enumerate(sentences):
42
  output += f"{index}: {sentence}\n"
 
75
  text_input_2 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
76
  text_button_2 = gr.Button("生成摘要")
77
  text_output_2 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
78
+ with gr.Tab("LexRank-MiniLM-L12-v2"):
79
+ text_input_3 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
80
+ text_button_3 = gr.Button("生成摘要")
81
+ text_output_3 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
82
  with gr.Tab("相似度检测"):
83
  with gr.Row():
84
  text_input_query = gr.Textbox(lines=10, label="查询文本")
 
87
  text_output_similarity = gr.Textbox()
88
 
89
  text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
90
+ text_button_2.click(extract_handler_distiluse_v1, inputs=text_input_2, outputs=text_output_2)
91
+ text_button_3.click(extract_handler_l12, inputs=text_input_3, outputs=text_output_3)
92
  text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
93
 
94
  app.launch(
lex_rank_L12.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy, nltk
2
+ nltk.download('punkt')
3
+
4
+
5
+ from harvesttext import HarvestText
6
+ from lex_rank_util import degree_centrality_scores
7
+ from sentence_transformers import SentenceTransformer, util
8
+
9
+
10
+ class LexRankL12(object):
11
+ def __init__(self):
12
+ self.model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
13
+ self.ht = HarvestText()
14
+
15
+ def find_central(self, content: str, num=100):
16
+ if self.contains_chinese(content):
17
+ sentences = self.ht.cut_sentences(content)
18
+ else:
19
+ sentences = nltk.sent_tokenize(content)
20
+ embeddings = self.model.encode(sentences, convert_to_tensor=True).cpu()
21
+
22
+ # Compute the pair-wise cosine similarities
23
+ cos_scores = util.cos_sim(embeddings, embeddings).numpy()
24
+
25
+ # Compute the centrality for each sentence
26
+ centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
27
+
28
+ # We argsort so that the first element is the sentence with the highest score
29
+ most_central_sentence_indices = numpy.argsort(-centrality_scores)
30
+
31
+ # num = 100
32
+ res = []
33
+ for index in most_central_sentence_indices:
34
+ if num < 0:
35
+ break
36
+ res.append(sentences[index])
37
+ num -= len(sentences[index])
38
+ return res
39
+
40
+ def contains_chinese(self, content: str):
41
+ for _char in content:
42
+ if '\u4e00' <= _char <= '\u9fa5':
43
+ return True
44
+ return False
lex_rank_new_model.py → lex_rank_distiluse_v1.py RENAMED
@@ -7,7 +7,7 @@ from lex_rank_util import degree_centrality_scores
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
10
- class LexRankNewModel(object):
11
  def __init__(self):
12
  self.model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
13
  self.ht = HarvestText()
 
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
10
+ class LexRankDistiluseV1(object):
11
  def __init__(self):
12
  self.model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
13
  self.ht = HarvestText()