hellopahe committed
Commit e7699c1
1 Parent(s): 90f83ff

add lexrank

LexRank.py ADDED
@@ -0,0 +1,124 @@
+"""
+LexRank implementation
+Source: https://github.com/crabcamp/lexrank/tree/dev
+"""
+
+import logging
+
+import numpy as np
+from scipy.sparse.csgraph import connected_components
+from scipy.special import softmax
+
+logger = logging.getLogger(__name__)
+
+
+def degree_centrality_scores(
+    similarity_matrix,
+    threshold=None,
+    increase_power=True,
+):
+    if not (
+        threshold is None
+        or isinstance(threshold, float)
+        and 0 <= threshold < 1
+    ):
+        raise ValueError(
+            '\'threshold\' should be a floating-point number '
+            'from the interval [0, 1) or None',
+        )
+
+    if threshold is None:
+        markov_matrix = create_markov_matrix(similarity_matrix)
+    else:
+        markov_matrix = create_markov_matrix_discrete(
+            similarity_matrix,
+            threshold,
+        )
+
+    scores = stationary_distribution(
+        markov_matrix,
+        increase_power=increase_power,
+        normalized=False,
+    )
+
+    return scores
+
+
+def _power_method(transition_matrix, increase_power=True, max_iter=10000):
+    eigenvector = np.ones(len(transition_matrix))
+
+    if len(eigenvector) == 1:
+        return eigenvector
+
+    transition = transition_matrix.transpose()
+
+    for _ in range(max_iter):
+        eigenvector_next = np.dot(transition, eigenvector)
+
+        if np.allclose(eigenvector_next, eigenvector):
+            return eigenvector_next
+
+        eigenvector = eigenvector_next
+
+        if increase_power:
+            transition = np.dot(transition, transition)
+
+    logger.warning(
+        'Maximum number of iterations for power method exceeded without convergence!'
+    )
+    return eigenvector_next
+
+
+def connected_nodes(matrix):
+    _, labels = connected_components(matrix)
+
+    groups = []
+
+    for tag in np.unique(labels):
+        group = np.where(labels == tag)[0]
+        groups.append(group)
+
+    return groups
+
+
+def create_markov_matrix(weights_matrix):
+    n_1, n_2 = weights_matrix.shape
+    if n_1 != n_2:
+        raise ValueError('\'weights_matrix\' should be square')
+
+    row_sum = weights_matrix.sum(axis=1, keepdims=True)
+
+    # Normalize the probability distribution differently if
+    # we have negative transition values.
+    if np.min(weights_matrix) <= 0:
+        return softmax(weights_matrix, axis=1)
+
+    return weights_matrix / row_sum
+
+
+def create_markov_matrix_discrete(weights_matrix, threshold):
+    discrete_weights_matrix = np.zeros(weights_matrix.shape)
+    ixs = np.where(weights_matrix >= threshold)
+    discrete_weights_matrix[ixs] = 1
+
+    return create_markov_matrix(discrete_weights_matrix)
+
+
+def stationary_distribution(
+    transition_matrix,
+    increase_power=True,
+    normalized=True,
+):
+    n_1, n_2 = transition_matrix.shape
+    if n_1 != n_2:
+        raise ValueError('\'transition_matrix\' should be square')
+
+    distribution = np.zeros(n_1)
+
+    grouped_indices = connected_nodes(transition_matrix)
+
+    for group in grouped_indices:
+        t_matrix = transition_matrix[np.ix_(group, group)]
+        eigenvector = _power_method(t_matrix, increase_power=increase_power)
+        distribution[group] = eigenvector
+
+    if normalized:
+        distribution /= n_1
+
+    return distribution
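
Not part of the commit, but as a quick orientation: `degree_centrality_scores` takes a square pairwise-similarity matrix and returns one centrality score per row. A minimal sketch with an invented 4x4 toy matrix (the values are illustrative only, not real sentence similarities):

```python
import numpy as np

from LexRank import degree_centrality_scores

# Invented toy similarity matrix for four "sentences"
# (symmetric, with 1.0 self-similarity on the diagonal).
similarity = np.array([
    [1.0, 0.8, 0.6, 0.1],
    [0.8, 1.0, 0.5, 0.2],
    [0.6, 0.5, 1.0, 0.1],
    [0.1, 0.2, 0.1, 1.0],
])

scores = degree_centrality_scores(similarity, threshold=None)

# Sort indices from most to least central sentence.
ranking = np.argsort(-scores)
print(ranking)
```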
app.py CHANGED
@@ -8,6 +8,10 @@ from embed import Embed
 import tensorflow as tf
 
+import numpy
+from harvesttext import HarvestText
+from sentence_transformers import SentenceTransformer, util
+from LexRank import degree_centrality_scores
+
 
 class SummaryExtractor(object):
     def __init__(self):
@@ -16,16 +20,51 @@ class SummaryExtractor(object):
         self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
         self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)
 
-    def extract(self, content: str, min=20, max=30) -> str:
-        return str(self.text2text_genr(content, do_sample=False, min_length=min, max_length=max, num_return_sequences=3)[0]["generated_text"])
+    def extract(self, content: str) -> str:
+        return str(self.text2text_genr(content, do_sample=False, num_return_sequences=3)[0]["generated_text"])
+
+
+class LexRank(object):
+    def __init__(self):
+        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
+        self.ht = HarvestText()
+
+    def find_central(self, content: str):
+        sentences = self.ht.cut_sentences(content)
+        embeddings = self.model.encode(sentences, convert_to_tensor=True)
+
+        # Compute the pairwise cosine similarities.
+        cos_scores = util.cos_sim(embeddings, embeddings).numpy()
+
+        # Compute the centrality for each sentence.
+        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
+
+        # Argsort so that the first element is the sentence with the
+        # highest score, then return the sentences in that order.
+        most_central_sentence_indices = numpy.argsort(-centrality_scores)
+        return [sentences[idx] for idx in most_central_sentence_indices]
+
+
+# ---===--- worker instances ---===---
 t_randeng = SummaryExtractor()
 embedder = Embed()
+lex = LexRank()
 
 
 def randeng_extract(content):
-    return t_randeng.extract(content)
+    # Keep the most central sentences that fit a ~500-character budget
+    # (always at least one), then summarize only that subset.
+    sentences = lex.find_central(content)
+    num = 500
+    ptr = len(sentences)
+    for index, sentence in enumerate(sentences):
+        num -= len(sentence)
+        if num < 0:
+            ptr = max(index, 1)
+            break
+    return t_randeng.extract("".join(sentences[:ptr]))
 
 
 def similarity_check(inputs: list):
@@ -42,13 +81,13 @@ with gr.Blocks() as app:
     # text_output = gr.Textbox()
     # text_button = gr.Button("生成摘要")
     with gr.Tab("Randeng-Pegasus-523M"):
-        text_input_1 = gr.Textbox()
-        text_output_1 = gr.Textbox()
+        text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
+        text_output_1 = gr.Textbox(label="摘要文本")
         text_button_1 = gr.Button("生成摘要")
     with gr.Tab("相似度检测"):
         with gr.Row():
-            text_input_query = gr.Textbox()
-            text_input_doc = gr.Textbox()
+            text_input_query = gr.Textbox(label="查询文本")
+            text_input_doc = gr.Textbox(lines=10, label="逐行输入待比较的文本列表")
         text_button_similarity = gr.Button("对比相似度")
         text_output_similarity = gr.Textbox()
 
@@ -56,4 +95,7 @@ with gr.Blocks() as app:
     text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
     text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
 
-app.launch()
+app.launch(
+    # share=True,
+    # debug=True
+)
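
The heart of the new `randeng_extract` flow is the budget step: sentences arrive sorted by centrality and are kept greedily until roughly 500 characters are consumed. A standalone sketch of just that step, with `select_by_budget` as an invented helper name and toy sentences (no models needed):

```python
def select_by_budget(ranked_sentences, budget=500):
    """Keep the most central sentences that fit the character budget (at least one)."""
    ptr = len(ranked_sentences)  # keep everything if the text already fits
    remaining = budget
    for index, sentence in enumerate(ranked_sentences):
        remaining -= len(sentence)
        if remaining < 0:
            ptr = max(index, 1)  # stop before the overrun, but never return nothing
            break
    return "".join(ranked_sentences[:ptr])


# Toy input, already ordered from most to least central.
ranked = ["最重要的句子。", "其次重要的句子。", "边缘细节。"]
print(select_by_budget(ranked, budget=12))  # -> 最重要的句子。
```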
article_extractor/tokenizers_pegasus.py CHANGED
@@ -20,7 +20,7 @@ import sys
 # sys.path.append("../../../../")
 
 jieba.dt.tmp_dir = os.path.expanduser(
-    "../tmp/")
+    "tmp/")
 # jieba.enable_parallel(8)
 jieba.initialize()
 
requirements.txt CHANGED
@@ -11,4 +11,8 @@ jieba
 deepspeed
 jieba-fast
 protobuf
-datasets
+datasets
+
+gradio
+
+sentence-transformers