hellopahe committed
Commit 235b9c1
1 Parent(s): 1870c14
Files changed (4)
  1. app.py +2 -11
  2. article_extractor/data_utils.py +321 -0
  3. embed.py +36 -0
  4. requirements.txt +3 -1
app.py CHANGED
@@ -1,7 +1,3 @@
-"""
-    This script summarizes from a paragraph(normally 512 length).
-"""
-
 import torch
 import gradio as gr
 
@@ -14,16 +10,9 @@ class SummaryExtractor(object):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = PegasusForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese').to(self.device)
         self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
-
-
         self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)
 
     def extract(self, content: str, min=20, max=30) -> str:
-        # inputs = self.tokenizer(content, max_length=512, return_tensors='pt').to(device=self.device)
-        # summary_ids = self.model.generate(inputs['input_ids'])
-        # return self.tokenizer.batch_decode(summary_ids,
-        #                                    skip_special_tokens=True,
-        #                                    clean_up_tokenization_spaces=False)[0]
         return str(self.text2text_genr(content, do_sample=False, min_length=min, max_length=max, num_return_sequences=3)[0]["generated_text"])
 
 
@@ -34,6 +23,8 @@ def randeng_extract(content):
     return t_randeng.extract(content)
 
 def similarity_check(query: str, doc: str):
+    doc_list = doc.split("\n")
+
     return "similarity result"
 
 with gr.Blocks() as app:
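The only functional change to app.py is the start of similarity_check: the document is split into paragraphs on newlines, which pairs naturally with the new embed.py module and the sentence-transformers dependency added below. A minimal sketch of how the stub could be completed, assuming the repository layout makes `from embed import Embed` resolvable; the ranking logic and the returned string are illustrative assumptions, not part of this commit:

# Hypothetical completion of similarity_check (not part of this commit).
# Embed.encode returns l2-normalized vectors, so a dot product is already
# the cosine similarity between the query and each paragraph.
import tensorflow as tf
from embed import Embed

embedder = Embed()

def similarity_check_sketch(query: str, doc: str) -> str:
    doc_list = [p for p in doc.split("\n") if p.strip()]
    if not doc_list:
        return "empty document"
    query_vec = embedder.encode([query])        # shape (1, dim)
    doc_vecs = embedder.encode(doc_list)        # shape (n, dim)
    scores = tf.squeeze(tf.matmul(doc_vecs, query_vec, transpose_b=True), axis=-1)
    best = int(tf.argmax(scores).numpy())
    return f"best match (score {float(scores[best]):.3f}): {doc_list[best]}"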
article_extractor/data_utils.py ADDED
@@ -0,0 +1,321 @@
+# -*- coding: utf-8 -*-
+
+# Used for
+
+import re
+import six
+import unicodedata
+import torch
+import rouge
+import numpy as np
+import random
+# from fengshen.examples.pegasus.pegasus_utils import text_segmentate
+import sys
+
+sys.path.append('../../../../')
+
+rouge = rouge.Rouge()
+
+
+is_py2 = six.PY2
+
+if not is_py2:
+    basestring = str
+
+
+def _is_chinese_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    #
+    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+    # despite its name. The modern Korean Hangul alphabet is a different block,
+    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+    # space-separated words, so they are not treated specially and handled
+    # like all of the other languages.
+    if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+        return True
+
+    return False
+
+
+def _is_whitespace(char):
+    """Checks whether `char` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == " " or char == "\t" or char == "\n" or char == "\r":
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `char` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == "\t" or char == "\n" or char == "\r":
+        return False
+    cat = unicodedata.category(char)
+    if cat.startswith("C"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `char` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (
+            cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
+
+
+def is_string(s):
+    """Check whether `s` is a string.
+    """
+    return isinstance(s, basestring)
+
+
+def is_stopwords(word, stopwords):
+    if word in stopwords:
+        return True
+    else:
+        return False
+
+
+def text_segmentate(text):
+    en_seg_pattern = '((?:\\!|\\?|\\.|\\n)+(?:\\s)+)'
+    ch_seg_pattern = '((?:？|！|。|\\n)+)'
+    try:
+        text = re.sub(en_seg_pattern, r'\1[SEP]', text)
+        # print("sub text: ", text)
+    except Exception as e:
+        print("input: ", text)
+        raise e
+    text = re.sub(ch_seg_pattern, r'\1[SEP]', text)
+    # print("sub ch text: ", text)
+    text_list = text.split("[SEP]")
+    text_list = list(filter(lambda x: len(x) != 0, text_list))
+    return text_list
+
+
+def load_stopwords(stopwords_path):
+    stopwords_dict = {}
+    with open(stopwords_path, "r") as rf:
+        for line in rf:
+            line = line.strip()
+            if line not in stopwords_dict:
+                stopwords_dict[line] = 0
+            else:
+                pass
+    return stopwords_dict
+
+
+def text_process(text, max_length):
+    """Split the text into sentence chunks.
+    """
+    texts = text_segmentate(text)
+
+    result, length = [], 0
+    for text in texts:
+        if length + len(text) > max_length * 1.3 and len(result) >= 3:
+            yield result
+            result, length = [], 0
+        result.append(text)
+        length += len(text)
+    if result and len(result) >= 3:
+        yield result
+
+
+def text_process_split_long_content(text, max_length):
+    """Split long text into chunks.
+    """
+    texts = text_segmentate(text)
+
+    result, sentence_num = "", 0
+    for text in texts:
+        if len(text) > 500:
+            if len(result) > 300 and sentence_num >= 3:
+                yield result
+                result, sentence_num = "", 0
+            else:
+                result, sentence_num = "", 0
+                continue
+        else:
+            if len(result) + len(text) > max_length * 1.1 and sentence_num >= 3:
+                yield result
+                result, sentence_num = "", 0
+            result += text
+            sentence_num += 1
+
+    if result and sentence_num >= 3:
+        yield result
+
+
+def gather_join(texts, idxs):
+    """Take the texts at the given indices and join them.
+    """
+    return ''.join([texts[i] for i in idxs])
+
+
+def gather_join_f1(texts_token, idsx):
+    join_texts = []
+    for id in idsx:
+        join_texts.extend(texts_token[id])
+    return join_texts
+
+
+def compute_rouge(source, target):
+    """Compute rouge-1, rouge-2 and rouge-l.
+    """
+    source, target = ' '.join(source), ' '.join(target)
+    try:
+        scores = rouge.get_scores(hyps=source, refs=target)
+        return {
+            'rouge-1': scores[0]['rouge-1']['f'],
+            'rouge-2': scores[0]['rouge-2']['f'],
+            'rouge-l': scores[0]['rouge-l']['f'],
+        }
+    except ValueError:
+        return {
+            'rouge-1': 0.0,
+            'rouge-2': 0.0,
+            'rouge-l': 0.0,
+        }
+
+
+def remove_stopwords(texts, stopwords_dict):
+    for i, text in enumerate(texts):
+        texts[i] = list(filter(lambda x: x not in stopwords_dict, text))
+    return texts
+
+
+def pseudo_summary_f1(texts,
+                      stopwords,
+                      tokenizer,
+                      max_length,
+                      rouge_strategy="rouge-l"):
+    """Build a pseudo-label summarization dataset.
+    """
+    summary_rate = 0.25
+    max_length = max_length - 1
+    texts_tokens = []
+    sentece_idxs_vec = []
+    for text in texts:
+        if len(text) == 0:
+            continue
+        try:
+            ids = tokenizer.encode(text.strip())[:-1]
+        except ValueError:
+            print("error, input : ", text)
+            raise ValueError
+        sentece_idxs_vec.append(ids)
+        tokens = [tokenizer._convert_id_to_token(token) for token in ids]
+        texts_tokens.append(tokens)
+
+    texts_tokens_rm = remove_stopwords(texts_tokens, stopwords)
+    source_idxs, target_idxs = list(range(len(texts))), []
+
+    assert len(texts_tokens) == len(texts)
+    # truncate_index = 0
+    while True:
+        sims = []
+        for i in source_idxs:
+            new_source_idxs = [j for j in source_idxs if j != i]
+            new_target_idxs = sorted(target_idxs + [i])
+            new_source = gather_join_f1(texts_tokens_rm, new_source_idxs)
+            new_target = gather_join_f1(texts_tokens_rm, new_target_idxs)
+            sim = compute_rouge(new_source, new_target)[rouge_strategy]
+            sims.append(sim)
+        new_idx = source_idxs[np.argmax(sims)]
+        del sims
+        source_idxs.remove(new_idx)
+        target_idxs = sorted(target_idxs + [new_idx])
+        source = gather_join(texts, source_idxs)
+        target = gather_join(texts, target_idxs)
+        try:
+            if (len(source_idxs) == 1
+                    or 1.0 * len(target) / len(source) > summary_rate):
+                break
+        except ZeroDivisionError as e:
+            print(e)
+            print(texts)
+            print("source: ", source)
+            print("target: ", target)
+
+    if len(source) < len(target):
+        source, target = target, source
+        source_idxs, target_idxs = target_idxs, source_idxs
+
+    return sentece_idxs_vec, source, target, source_idxs, target_idxs
+
+
+def get_input_mask(sentence_id_vec, indexs):
+    target_idxs = []
+    input_idxs = []
+    kMaskSentenceTokenId = 2
+    kEosTokenId = 1
+    mask_sentence_options_cumulative_prob = [0.9, 0.9, 1, 1]
+    for index in indexs:
+        target_idxs.extend(sentence_id_vec[index])
+        choice = random.uniform(0, 1)
+        if choice < mask_sentence_options_cumulative_prob[0]:
+            # print("mask index: ", index)
+            sentence_id_vec[index] = [kMaskSentenceTokenId]
+        elif choice < mask_sentence_options_cumulative_prob[1]:
+            # print("replace index: ", index)
+            replace_id = random.randint(0, len(sentence_id_vec) - 1)
+            sentence_id_vec[index] = sentence_id_vec[replace_id]
+        elif choice < mask_sentence_options_cumulative_prob[2]:
+            pass
+        else:
+            sentence_id_vec[index] = []
+
+    target_idxs.append(kEosTokenId)
+    # print(sentence_id_vec)
+    for index, sentence_id in enumerate(sentence_id_vec):
+        # print(index, sentence_id)
+        if len(sentence_id) == 0:
+            continue
+        input_idxs.extend(sentence_id_vec[index])
+
+    input_idxs.append(kEosTokenId)
+    return input_idxs, target_idxs
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
+                       decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+def padding_to_maxlength(ids, max_length, pad_id):
+    cur_len = len(ids)
+    len_diff = max_length - cur_len
+    return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
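text_segmentate splits raw text on Chinese and English sentence-final punctuation, and text_process regroups the sentences into chunks that stay near a character budget. A small usage sketch, assuming the repository root is on sys.path so the module imports as article_extractor.data_utils (importing it pulls in torch, rouge, six and numpy):

# Usage sketch (not part of this commit) for the segmentation helpers.
from article_extractor.data_utils import text_segmentate, text_process

text = "今天天气不错。我们去公园散步吧！It is sunny today. Shall we go out?\n好的。"

# Sentences end at 。！？ or . ! ? followed by whitespace, or at a newline.
print(text_segmentate(text))

# text_process is a generator: it yields lists of consecutive sentences whose
# combined length stays near max_length, and only emits a chunk once it holds
# at least three sentences.
for chunk in text_process(text, max_length=20):
    print(chunk)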
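pseudo_summary_f1 builds PEGASUS-style pseudo-labels: it greedily moves the sentence with the highest ROUGE overlap against the rest of the document into the target side, stopping once the target reaches roughly a quarter of the source length. A sketch of the call shape with a deliberately trivial stand-in tokenizer; the real pipeline would pass a Pegasus tokenizer and a loaded stopword dict, so everything below is illustrative only:

# Illustrative only (not part of this commit): a toy whitespace tokenizer that
# provides the two methods pseudo_summary_f1 relies on.
from article_extractor.data_utils import pseudo_summary_f1, text_segmentate


class ToyTokenizer:
    def __init__(self):
        self.vocab, self.inv = {}, {}

    def encode(self, text):
        ids = []
        for word in text.split():
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
                self.inv[self.vocab[word]] = word
            ids.append(self.vocab[word])
        return ids + [-1]  # pseudo_summary_f1 drops the final id with [:-1]

    def _convert_id_to_token(self, idx):
        return self.inv[idx]


doc = ("The cat sat on the mat . "
       "The dog chased the cat across the yard . "
       "Later the cat slept on the mat all afternoon . "
       "The dog slept too .")
sentences = text_segmentate(doc)
ids_vec, source, target, src_idx, tgt_idx = pseudo_summary_f1(
    sentences, stopwords={}, tokenizer=ToyTokenizer(), max_length=128)
print("pseudo target:", target)  # the sentence(s) with the highest ROUGE overlap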
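shift_tokens_right and padding_to_maxlength are the low-level helpers for building decoder inputs and attention masks. A toy sketch of their behaviour, under the same import-path assumption as above:

# Toy sketch (not part of this commit) of the two tensor helpers.
import torch
from article_extractor.data_utils import padding_to_maxlength, shift_tokens_right

PAD_ID = 0
DECODER_START_ID = 0

# Pad a label sequence to a fixed length and build its attention mask.
labels, attn_mask = padding_to_maxlength([8, 15, 23, 1], max_length=6, pad_id=PAD_ID)
print(labels)     # [8, 15, 23, 1, 0, 0]
print(attn_mask)  # [1, 1, 1, 1, 0, 0]

# The decoder sees the labels shifted one position to the right, starting from
# decoder_start_token_id; any -100 ignore-index values are replaced by pad_id.
decoder_input_ids = shift_tokens_right(torch.tensor([labels]), PAD_ID, DECODER_START_ID)
print(decoder_input_ids)  # tensor([[0, 8, 15, 23, 1, 0]])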
embed.py ADDED
@@ -0,0 +1,36 @@
+"""
+This script turns a list of strings into embeddings.
+"""
+from transformers import AutoTokenizer, TFAutoModel
+import tensorflow as tf
+
+
+class Embed(object):
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
+        self.model = TFAutoModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
+
+    # Mean pooling - take the attention mask into account for correct averaging
+    @staticmethod
+    def mean_pooling(model_output, attention_mask):
+        token_embeddings = model_output.last_hidden_state
+        input_mask_expanded = tf.cast(tf.tile(tf.expand_dims(attention_mask, -1), [1, 1, token_embeddings.shape[-1]]),
+                                      tf.float32)
+        return tf.math.reduce_sum(token_embeddings * input_mask_expanded, 1) / tf.math.maximum(
+            tf.math.reduce_sum(input_mask_expanded, 1), 1e-9)
+
+    # Encode text
+    def encode(self, texts):
+        # Tokenize sentences
+        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='tf')
+
+        # Compute token embeddings
+        model_output = self.model(**encoded_input, return_dict=True)
+
+        # Perform pooling
+        embeddings = Embed.mean_pooling(model_output, encoded_input['attention_mask'])
+
+        # Normalize embeddings
+        embeddings = tf.math.l2_normalize(embeddings, axis=1)
+
+        return embeddings
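mean_pooling averages only the token vectors that the attention mask marks as real, so padded positions do not dilute the sentence embedding. A tiny numeric check (not part of the commit); SimpleNamespace stands in for the model output object:

# Minimal numeric sketch of what Embed.mean_pooling computes.
from types import SimpleNamespace
import tensorflow as tf
from embed import Embed

token_embeddings = tf.constant([[[1.0, 3.0],
                                 [3.0, 5.0],
                                 [9.0, 9.0]]])   # (batch=1, tokens=3, dim=2)
attention_mask = tf.constant([[1, 1, 0]])        # the third token is padding

pooled = Embed.mean_pooling(SimpleNamespace(last_hidden_state=token_embeddings),
                            attention_mask)
print(pooled.numpy())  # [[2. 4.]]: the average of the two real token vectors only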
requirements.txt CHANGED
@@ -4,4 +4,6 @@ tensorflow
 transformers
 peft
 pandas
-sentence-transformers
+sentence-transformers
+
+rouge