hellopahe committed
Commit 93e5f33
1 Parent(s): 7090141

Add Chinese/English sentence-segmentation detection

app.py CHANGED
@@ -1,126 +1,25 @@
-import math
+import math, torch, gradio as gr
 
-import numpy
-import torch
-import gradio as gr
-
-from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline, AutoModel, AutoTokenizer
-from article_extractor.tokenizers_pegasus import PegasusTokenizer
-
-import tensorflow as tf
-
-from harvesttext import HarvestText
+from lex_rank import LexRank
 from sentence_transformers import SentenceTransformer, util
-from LexRank import degree_centrality_scores
-
-from luotuo_util import DeviceMap
-from peft import get_peft_model, LoraConfig, TaskType
-
-
-class SummaryExtractor(object):
-    def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = PegasusForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese').to(self.device)
-        self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
-        self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)
-
-    def extract(self, content: str) -> str:
-        print(content)
-        return str(self.text2text_genr(content, do_sample=False, num_return_sequences=3)[0]["generated_text"])
-
-class Tuoling_6B_extractor(object):
-    def __init__(self):
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
-        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-        self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map=DeviceMap("ChatGLM").get())
-
-        # load fine-tuned pretrained model.
-        peft_path = "./luotuoC.pt"
-        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1)
-        self.model = get_peft_model(self.model, peft_config)
-        self.model.load_state_dict(torch.load(peft_path), strict=False)
-        torch.set_default_tensor_type(torch.cuda.FloatTensor)
-
-    @staticmethod
-    def format_example(example: dict) -> dict:
-        context = f"Instruction: {example['instruction']}\n"
-        if example.get("input"):
-            context += f"Input: {example['input']}\n"
-        context += "Answer: "
-        target = example["output"]
-        return {"context": context, "target": target}
 
-    def extract(self, instruction: str, input=None) -> str:
-        with torch.no_grad():
-            feature = Tuoling_6B_extractor.format_example(
-                {"instruction": "请帮我总结以下内容", "output": "", "input": f"{instruction}"}
-            )
-            input_text = feature["context"]
-            input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
-            out = self.model.generate(input_ids=input_ids, max_length=2048, temperature=0)
-            answer = self.tokenizer.decode(out[0])
-            return answer.split('Answer:')[1]
 
-class LexRank(object):
-    def __init__(self):
-        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
-        self.ht = HarvestText()
-
-    def find_central(self, content: str, num=100):
-        sentences = self.ht.cut_sentences(content)
-        embeddings = self.model.encode(sentences, convert_to_tensor=True).cpu()
-
-        # Compute the pair-wise cosine similarities
-        cos_scores = util.cos_sim(embeddings, embeddings).numpy()
-
-        # Compute the centrality for each sentence
-        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
-
-        # We argsort so that the first element is the sentence with the highest score
-        most_central_sentence_indices = numpy.argsort(-centrality_scores)
-
-        # num = 100
-        res = []
-        for index in most_central_sentence_indices:
-            if num < 0:
-                break
-            res.append(sentences[index])
-            num -= len(sentences[index])
-        return res
-
-# ---===--- worker instances ---===---
-# t_randeng = SummaryExtractor()
-# t_tuoling = Tuoling_6B_extractor()
-
-# embedder = Embed()
+# ---===--- instances ---===---
 embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
 lex = LexRank()
 
 
-def randeng_extract(content):
+# 摘要方法
+def extract_handler(content):
     summary_length = math.ceil(len(content) / 10)
     sentences = lex.find_central(content, num=summary_length)
     output = ""
     for index, sentence in enumerate(sentences):
         output += f"{index}: {sentence}\n"
-    # output += "摘要:\n"
-    # for index, sentence in enumerate(sentences):
-    #     output += f"{index}: {t_randeng.extract(sentence)}\n"
     return output
 
-# def tuoling_extract(content):
-#     sentences = lex.find_central(content)
-#     return str(list(t_tuoling.extract(sentence) for sentence in sentences))
-
-def similarity_check(query, doc):
-    doc_list = doc.split("\n")
-
-    query_embedding = embedder.encode(query)
-    doc_embedding = embedder.encode(doc_list)
-    scores = (query_embedding @ tf.transpose(doc_embedding))[0].numpy().tolist()
-    # scores = list(util.cos_sim(embedding_list[-1], doc_embedding) for doc_embedding in embedding_list[:-1])
-    return str(scores)
 
+# 相似度检测方法
 def similarity_search(queries, doc):
     doc_list = doc.split('\n')
     query_list = queries.split('\n')
@@ -141,17 +40,13 @@ def similarity_search(queries, doc):
     return output
 
 
+# web ui
 with gr.Blocks() as app:
     gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
     with gr.Tab("LexRank"):
         text_input_1 = gr.Textbox(label="请输入长文本:", lines=10, max_lines=1000)
         text_button_1 = gr.Button("生成摘要")
         text_output_1 = gr.Textbox(label="摘要文本(长度设置为原文长度的1/10)", lines=10)
-
-    # with gr.Tab("LexRank->Tuoling-6B-chatGLM"):
-    #     text_input = gr.Textbox(label="请输入长文本:", max_lines=1000)
-    #     text_output = gr.Textbox(label="摘要文本")
-    #     text_button = gr.Button("生成摘要")
    with gr.Tab("相似度检测"):
         with gr.Row():
             text_input_query = gr.Textbox(lines=10, label="查询文本")
@@ -159,11 +54,13 @@ with gr.Blocks() as app:
         text_button_similarity = gr.Button("对比相似度")
         text_output_similarity = gr.Textbox()
 
-    # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
-    text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
+    text_button_1.click(extract_handler, inputs=text_input_1, outputs=text_output_1)
     text_button_similarity.click(similarity_search, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
 
 app.launch(
+    # enable share will generate a temporary public link.
     # share=True,
-    # debug=True
+    # debug=True,
+    auth=("qee", "world"),
+    auth_message="请登陆"
 )
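
With this commit app.py drops the Pegasus and ChatGLM summarizers and keeps only the LexRank extractive path, now imported from the new lex_rank module, and app.launch gains Gradio basic auth. Below is a minimal local smoke test of the refactored flow; it is illustrative only (not part of the commit) and assumes lex_rank.py from this repository is importable and the paraphrase-multilingual-mpnet-base-v2 model can be downloaded.

import math
from lex_rank import LexRank

lex = LexRank()

def summarize(content: str) -> str:
    # Mirrors extract_handler in app.py: keep roughly 1/10 of the input length.
    summary_length = math.ceil(len(content) / 10)
    sentences = lex.find_central(content, num=summary_length)
    return "".join(f"{i}: {s}\n" for i, s in enumerate(sentences))

print(summarize(
    "Gradio builds web demos from Python functions. "
    "This Space wires two tabs to two handlers. "
    "The LexRank tab returns the most central sentences of the input."
))
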
article_extractor/data_utils.py DELETED
@@ -1,321 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # 用于
4
-
5
- import re
6
- import six
7
- import unicodedata
8
- import torch
9
- import rouge
10
- import numpy as np
11
- import random
12
- # from fengshen.examples.pegasus.pegasus_utils import text_segmentate
13
- import sys
14
-
15
- sys.path.append('../../../../')
16
-
17
- rouge = rouge.Rouge()
18
-
19
-
20
- is_py2 = six.PY2
21
-
22
- if not is_py2:
23
- basestring = str
24
-
25
-
26
- def _is_chinese_char(cp):
27
- """Checks whether CP is the codepoint of a CJK character."""
28
- # This defines a "chinese character" as anything in the CJK Unicode block:
29
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
30
- #
31
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
32
- # despite its name. The modern Korean Hangul alphabet is a different block,
33
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
34
- # space-separated words, so they are not treated specially and handled
35
- # like the all of the other languages.
36
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
37
- or (cp >= 0x20000 and cp <= 0x2A6DF)
38
- or (cp >= 0x2A700 and cp <= 0x2B73F)
39
- or (cp >= 0x2B740 and cp <= 0x2B81F)
40
- or (cp >= 0x2B820 and cp <= 0x2CEAF)
41
- or (cp >= 0xF900 and cp <= 0xFAFF)
42
- or (cp >= 0x2F800 and cp <= 0x2FA1F)):
43
- return True
44
-
45
- return False
46
-
47
-
48
- def _is_whitespace(char):
49
- """Checks whether `char` is a whitespace character."""
50
- # \t, \n, and \r are technically control characters but we treat them
51
- # as whitespace since they are generally considered as such.
52
- if char == " " or char == "\t" or char == "\n" or char == "\r":
53
- return True
54
- cat = unicodedata.category(char)
55
- if cat == "Zs":
56
- return True
57
- return False
58
-
59
-
60
- def _is_control(char):
61
- """Checks whether `char` is a control character."""
62
- # These are technically control characters but we count them as whitespace
63
- # characters.
64
- if char == "\t" or char == "\n" or char == "\r":
65
- return False
66
- cat = unicodedata.category(char)
67
- if cat.startswith("C"):
68
- return True
69
- return False
70
-
71
-
72
- def _is_punctuation(char):
73
- """Checks whether `char` is a punctuation character."""
74
- cp = ord(char)
75
- # We treat all non-letter/number ASCII as punctuation.
76
- # Characters such as "^", "$", and "`" are not in the Unicode
77
- # Punctuation class but we treat them as punctuation anyways, for
78
- # consistency.
79
- if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (
80
- cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
81
- return True
82
- cat = unicodedata.category(char)
83
- if cat.startswith("P"):
84
- return True
85
- return False
86
-
87
-
88
- def is_string(s):
89
- """判断是否是字符串
90
- """
91
- return isinstance(s, basestring)
92
-
93
-
94
- def is_stopwords(word, stopwords):
95
- if word in stopwords:
96
- return True
97
- else:
98
- return False
99
-
100
-
101
- def text_segmentate(text):
102
- en_seg_pattern = '((?:\\!|\\?|\\.|\\n)+(?:\\s)+)'
103
- ch_seg_pattern = '((?:?|!|。|\\n)+)'
104
- try:
105
- text = re.sub(en_seg_pattern, r'\1[SEP]', text)
106
- # print("sub text: ", text)
107
- except Exception as e:
108
- print("input: ", text)
109
- raise e
110
- text = re.sub(ch_seg_pattern, r'\1[SEP]', text)
111
- # print("sub ch text: ", text)
112
- text_list = text.split("[SEP]")
113
- text_list = list(filter(lambda x: len(x) != 0, text_list))
114
- return text_list
115
-
116
-
117
- def load_stopwords(stopwords_path):
118
- stopwords_dict = {}
119
- with open(stopwords_path, "r") as rf:
120
- for line in rf:
121
- line = line.strip()
122
- if line not in stopwords_dict:
123
- stopwords_dict[line] = 0
124
- else:
125
- pass
126
- return stopwords_dict
127
-
128
-
129
- def text_process(text, max_length):
130
- """分割文本
131
- """
132
- texts = text_segmentate(text)
133
-
134
- result, length = [], 0
135
- for text in texts:
136
- if length + len(text) > max_length * 1.3 and len(result) >= 3:
137
- yield result
138
- result, length = [], 0
139
- result.append(text)
140
- length += len(text)
141
- if result and len(result) >= 3:
142
- yield result
143
-
144
-
145
- def text_process_split_long_content(text, max_length):
146
- """分割长文本
147
- """
148
- texts = text_segmentate(text)
149
-
150
- result, sentence_num = "", 0
151
- for text in texts:
152
- if len(text) > 500:
153
- if len(result) > 300 and sentence_num >= 3:
154
- yield result
155
- result, sentence_num = "", 0
156
- else:
157
- result, sentence_num = "", 0
158
- continue
159
- else:
160
- if len(result) + len(text) > max_length * 1.1 and sentence_num >= 3:
161
- yield result
162
- result, sentence_num = "", 0
163
- result += text
164
- sentence_num += 1
165
-
166
- if result and sentence_num >= 3:
167
- yield result
168
-
169
-
170
- def gather_join(texts, idxs):
171
- """取出对应的text,然后拼接起来
172
- """
173
- return ''.join([texts[i] for i in idxs])
174
-
175
-
176
- def gather_join_f1(texts_token, idsx):
177
- join_texts = []
178
- for id in idsx:
179
- join_texts.extend(texts_token[id])
180
- return join_texts
181
-
182
-
183
- def compute_rouge(source, target):
184
- """计算rouge-1、rouge-2、rouge-l
185
- """
186
- source, target = ' '.join(source), ' '.join(target)
187
- try:
188
- scores = rouge.get_scores(hyps=source, refs=target)
189
- return {
190
- 'rouge-1': scores[0]['rouge-1']['f'],
191
- 'rouge-2': scores[0]['rouge-2']['f'],
192
- 'rouge-l': scores[0]['rouge-l']['f'],
193
- }
194
- except ValueError:
195
- return {
196
- 'rouge-1': 0.0,
197
- 'rouge-2': 0.0,
198
- 'rouge-l': 0.0,
199
- }
200
-
201
-
202
- def remove_stopwords(texts, stopwords_dict):
203
- for i, text in enumerate(texts):
204
- texts[i] = list(filter(lambda x: x not in stopwords_dict, text))
205
- return texts
206
-
207
-
208
- def pseudo_summary_f1(texts,
209
- stopwords,
210
- tokenizer,
211
- max_length,
212
- rouge_strategy="rouge-l"):
213
- """构建伪标签摘要数据集
214
- """
215
- summary_rate = 0.25
216
- max_length = max_length - 1
217
- texts_tokens = []
218
- sentece_idxs_vec = []
219
- for text in texts:
220
- if len(texts) == 0:
221
- continue
222
- try:
223
- ids = tokenizer.encode(text.strip())[:-1]
224
- except ValueError:
225
- print("error, input : ", text)
226
- raise ValueError
227
- sentece_idxs_vec.append(ids)
228
- tokens = [tokenizer._convert_id_to_token(token) for token in ids]
229
- texts_tokens.append(tokens)
230
-
231
- texts_tokens_rm = remove_stopwords(texts_tokens, stopwords)
232
- source_idxs, target_idxs = list(range(len(texts))), []
233
-
234
- assert len(texts_tokens) == len(texts)
235
- # truncate_index = 0
236
- while True:
237
- sims = []
238
- for i in source_idxs:
239
- new_source_idxs = [j for j in source_idxs if j != i]
240
- new_target_idxs = sorted(target_idxs + [i])
241
- new_source = gather_join_f1(texts_tokens_rm, new_source_idxs)
242
- new_target = gather_join_f1(texts_tokens_rm, new_target_idxs)
243
- sim = compute_rouge(new_source, new_target)[rouge_strategy]
244
- sims.append(sim)
245
- new_idx = source_idxs[np.argmax(sims)]
246
- del sims
247
- source_idxs.remove(new_idx)
248
- target_idxs = sorted(target_idxs + [new_idx])
249
- source = gather_join(texts, source_idxs)
250
- target = gather_join(texts, target_idxs)
251
- try:
252
- if (len(source_idxs) == 1
253
- or 1.0 * len(target) / len(source) > summary_rate):
254
- break
255
- except ZeroDivisionError as e:
256
- print(e.meesage)
257
- print(texts)
258
- print("source: ", source)
259
- print("target: ", target)
260
-
261
- if len(source) < len(target):
262
- source, target = target, source
263
- source_idxs, target_idxs = target_idxs, source_idxs
264
-
265
- return sentece_idxs_vec, source, target, source_idxs, target_idxs
266
-
267
-
268
- def get_input_mask(sentence_id_vec, indexs):
269
- target_idxs = []
270
- input_idxs = []
271
- kMaskSentenceTokenId = 2
272
- kEosTokenId = 1
273
- mask_sentence_options_cumulative_prob = [0.9, 0.9, 1, 1]
274
- for index in indexs:
275
- target_idxs.extend(sentence_id_vec[index])
276
- choice = random.uniform(0, 1)
277
- if choice < mask_sentence_options_cumulative_prob[0]:
278
- # print("mask index: ", index)
279
- sentence_id_vec[index] = [kMaskSentenceTokenId]
280
- elif choice < mask_sentence_options_cumulative_prob[1]:
281
- # print("replace index: ", index)
282
- replace_id = random.randint(0, len(sentence_id_vec))
283
- sentence_id_vec[index] = sentence_id_vec[replace_id]
284
- elif choice < mask_sentence_options_cumulative_prob[2]:
285
- pass
286
- else:
287
- sentence_id_vec[index] = []
288
-
289
- target_idxs.append(kEosTokenId)
290
- # print(sentence_id_vec)
291
- for index, sentence_id in enumerate(sentence_id_vec):
292
- # print(index, sentence_id)
293
- if len(sentence_id) == 0:
294
- continue
295
- input_idxs.extend(sentence_id_vec[index])
296
-
297
- input_idxs.append(kEosTokenId)
298
- return input_idxs, target_idxs
299
-
300
-
301
- def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
302
- decoder_start_token_id: int):
303
- """
304
- Shift input ids one token to the right.
305
- """
306
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
307
- shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
308
- shifted_input_ids[:, 0] = decoder_start_token_id
309
-
310
- if pad_token_id is None:
311
- raise ValueError("self.model.config.pad_token_id has to be defined.")
312
- # replace possible -100 values in labels by `pad_token_id`
313
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
314
-
315
- return shifted_input_ids
316
-
317
-
318
- def padding_to_maxlength(ids, max_length, pad_id):
319
- cur_len = len(ids)
320
- len_diff = max_length - cur_len
321
- return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
 
 
article_extractor/tokenizers_pegasus.py DELETED
@@ -1,602 +0,0 @@
1
- import sys
2
- sys.path.append('../')
3
- from article_extractor.data_utils import (
4
- _is_control,
5
- _is_punctuation,
6
- _is_whitespace,
7
- _is_chinese_char)
8
- from transformers import PreTrainedTokenizer
9
- from transformers import logging
10
- from typing import List, Optional, Tuple, Union
11
- import collections
12
- import os
13
- import unicodedata
14
- import re
15
- import jieba
16
- import sys
17
-
18
- # 提取摘要逻辑实现
19
-
20
- # sys.path.append("../../../../")
21
-
22
- jieba.dt.tmp_dir = os.path.expanduser(
23
- "tmp/")
24
- # jieba.enable_parallel(8)
25
- jieba.initialize()
26
-
27
- logger = logging.get_logger(__name__)
28
-
29
- VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
-
31
-
32
- def load_vocab(vocab_file):
33
- """Loads a vocabulary file into a dictionary."""
34
- vocab = collections.OrderedDict()
35
- with open(vocab_file, "r", encoding="utf-8") as reader:
36
- tokens = reader.readlines()
37
- for index, token in enumerate(tokens):
38
- token = token.rstrip("\n")
39
- vocab[token] = index
40
- return vocab
41
-
42
-
43
- def whitespace_tokenize(text):
44
- """Runs basic whitespace cleaning and splitting on a piece of text."""
45
- text = text.strip()
46
- if not text:
47
- return []
48
- tokens = text.split()
49
- return tokens
50
-
51
-
52
- class PegasusTokenizer(PreTrainedTokenizer):
53
- # copy from BertTokenizer
54
- r"""
55
- Construct a Pegasus tokenizer. Based on WordPiece.
56
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
57
- this superclass for more information regarding those methods.
58
- Args:
59
- vocab_file (`str`):
60
- File containing the vocabulary.
61
- do_lower_case (`bool`, *optional*, defaults to `True`):
62
- Whether or not to lowercase the input when tokenizing.
63
- do_basic_tokenize (`bool`, *optional*, defaults to `True`):
64
- Whether or not to do basic tokenization before WordPiece.
65
- never_split (`Iterable`, *optional*):
66
- Collection of tokens which will never be split during tokenization. Only has an effect when
67
- `do_basic_tokenize=True`
68
- unk_token (`str`, *optional*, defaults to `"[UNK]"`):
69
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
70
- token instead.
71
- sep_token (`str`, *optional*, defaults to `"[SEP]"`):
72
- The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
- sequence classification or for a text and a question for question answering. It is also used as the last
74
- token of a sequence built with special tokens.
75
- pad_token (`str`, *optional*, defaults to `"[PAD]"`):
76
- The token used for padding, for example when batching sequences of different lengths.
77
- cls_token (`str`, *optional*, defaults to `"[CLS]"`):
78
- The classifier token which is used when doing sequence classification (classification of the whole sequence
79
- instead of per-token classification). It is the first token of the sequence when built with special tokens.
80
- mask_token (`str`, *optional*, defaults to `"[MASK]"`):
81
- The token used for masking values. This is the token used when training this model with masked language
82
- modeling. This is the token which the model will try to predict.
83
- tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
84
- Whether or not to tokenize Chinese characters.
85
- This should likely be deactivated for Japanese (see this
86
- [issue](https://github.com/huggingface/transformers/issues/328)).
87
- strip_accents (`bool`, *optional*):
88
- Whether or not to strip all accents. If this option is not specified, then it will be determined by the
89
- value for `lowercase` (as in the original BERT).
90
- """
91
-
92
- vocab_files_names = VOCAB_FILES_NAMES
93
- model_input_names = ["input_ids", "attention_mask"]
94
-
95
- # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
96
- # pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
97
- # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
-
99
- def __init__(self,
100
- vocab_file,
101
- do_lower_case=True,
102
- do_basic_tokenize=True,
103
- never_split=None,
104
- pad_token="<pad>",
105
- eos_token="</s>",
106
- unk_token="<unk>",
107
- mask_token="<mask_2>",
108
- mask_token_sent="<mask_1>",
109
- additional_special_tokens=None,
110
- sep_token="[SEP]",
111
- cls_token="[CLS]",
112
- tokenize_chinese_chars=True,
113
- strip_accents=None,
114
- offset=100,
115
- pre_tokenizer=lambda x: jieba.cut(x, HMM=False),
116
- **kwargs):
117
- self.offset = offset
118
-
119
- if additional_special_tokens is not None:
120
- if not isinstance(additional_special_tokens, list):
121
- raise TypeError(
122
- f"additional_special_tokens should be of type {type(list)}, \
123
- but is {type(additional_special_tokens)}"
124
- )
125
-
126
- additional_special_tokens_extended = (
127
- ([mask_token_sent] + additional_special_tokens)
128
- if mask_token_sent not in additional_special_tokens
129
- and mask_token_sent is not None else additional_special_tokens)
130
-
131
- # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
132
- additional_special_tokens_extended += [
133
- f"<unk_{i}>" for i in range(
134
- len(additional_special_tokens_extended), self.offset - 1)
135
- ]
136
-
137
- if len(set(additional_special_tokens_extended)) != len(
138
- additional_special_tokens_extended):
139
- raise ValueError(
140
- f"Please make sure that the provided additional_special_tokens \
141
- do not contain an incorrectly shifted list of <unk_x> tokens. \
142
- Found {additional_special_tokens_extended}."
143
- )
144
- additional_special_tokens = additional_special_tokens_extended
145
- else:
146
- additional_special_tokens = [
147
- mask_token_sent
148
- ] if mask_token_sent is not None else []
149
- # additional_special_tokens += [f"<unk_{i}>" for i in range(3, self.offset)]
150
-
151
- # print("additional_special_tokens: ", additional_special_tokens)
152
-
153
- if not os.path.isfile(vocab_file):
154
- raise ValueError(
155
- f"Can't find a vocabulary file at path '{vocab_file}'. \
156
- To load the vocabulary from a Google pretrained "
157
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
158
- )
159
-
160
- super().__init__(
161
- do_lower_case=do_lower_case,
162
- do_basic_tokenize=do_basic_tokenize,
163
- never_split=never_split,
164
- unk_token=unk_token,
165
- sep_token=sep_token,
166
- pad_token=pad_token,
167
- cls_token=cls_token,
168
- mask_token=mask_token,
169
- eos_token=eos_token,
170
- tokenize_chinese_chars=tokenize_chinese_chars,
171
- additional_special_tokens=additional_special_tokens,
172
- strip_accents=strip_accents,
173
- **kwargs,
174
- )
175
-
176
- self.pre_tokenizer = pre_tokenizer
177
- self.mask_token_sent = mask_token_sent
178
- self.vocab = load_vocab(vocab_file)
179
-
180
- self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
181
- # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
182
- self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
183
- self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
184
-
185
- if self.mask_token_sent is not None:
186
- self.vocab[self.mask_token] = self.vocab.pop("[unused3]")
187
- self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]")
188
-
189
- self.ids_to_tokens = collections.OrderedDict([
190
- (ids, tok) for tok, ids in self.vocab.items()
191
- ])
192
- self.do_basic_tokenize = do_basic_tokenize
193
- if do_basic_tokenize:
194
- self.basic_tokenizer = BasicTokenizer(
195
- do_lower_case=do_lower_case,
196
- never_split=never_split,
197
- tokenize_chinese_chars=tokenize_chinese_chars,
198
- strip_accents=strip_accents,
199
- )
200
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
201
- unk_token=self.unk_token)
202
-
203
- @property
204
- def do_lower_case(self):
205
- return self.basic_tokenizer.do_lower_case
206
-
207
- @property
208
- def vocab_size(self):
209
- return len(self.vocab)
210
-
211
- def get_vocab(self):
212
- return dict(self.vocab, **self.added_tokens_encoder)
213
-
214
- def _tokenize(self, text):
215
- split_tokens = []
216
- # print("pegasus_tokenizer: ", text)
217
- for text in self.pre_tokenizer(text):
218
- if text in self.vocab:
219
- split_tokens.append(text)
220
- else:
221
- if self.do_basic_tokenize:
222
- for token in self.basic_tokenizer.tokenize(
223
- text, never_split=self.all_special_tokens):
224
-
225
- # If the token is part of the never_split set
226
- if token in self.basic_tokenizer.never_split:
227
- split_tokens.append(token)
228
- else:
229
- split_tokens += self.wordpiece_tokenizer.tokenize(
230
- token)
231
- else:
232
- split_tokens = self.wordpiece_tokenizer.tokenize(text)
233
- return split_tokens
234
-
235
- def _convert_token_to_id(self, token):
236
- """Converts a token (str) in an id using the vocab."""
237
- return self.vocab.get(token, self.vocab.get(self.unk_token))
238
-
239
- def _convert_id_to_token(self, index):
240
- """Converts an index (integer) in a token (str) using the vocab."""
241
- return self.ids_to_tokens.get(index, self.unk_token)
242
-
243
- @staticmethod
244
- def _cjk_punctuation():
245
- return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\
246
- \uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\
247
- \uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\
248
- \u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\
249
- \u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
250
-
251
- def convert_ids_to_tokens(
252
- self,
253
- ids: Union[int, List[int]],
254
- skip_special_tokens: bool = False) -> Union[str, List[str]]:
255
- """
256
- Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
257
- added tokens.
258
- Args:
259
- ids (`int` or `List[int]`):
260
- The token id (or token ids) to convert to tokens.
261
- skip_special_tokens (`bool`, *optional*, defaults to `False`):
262
- Whether or not to remove special tokens in the decoding.
263
- Returns:
264
- `str` or `List[str]`: The decoded token(s).
265
- """
266
- if isinstance(ids, int):
267
- if ids in self.added_tokens_decoder:
268
- return self.added_tokens_decoder[ids]
269
- else:
270
- return self._convert_id_to_token(ids)
271
- tokens = []
272
- for index in ids:
273
- index = int(index)
274
- if skip_special_tokens and index in self.all_special_ids and index != 2:
275
- continue
276
- if index in self.added_tokens_decoder:
277
- tokens.append(self.added_tokens_decoder[index])
278
- else:
279
- tokens.append(self._convert_id_to_token(index))
280
- return tokens
281
-
282
- def convert_tokens_to_string(self, tokens):
283
- """Converts a sequence of tokens (string) in a single string."""
284
- # for token in
285
- # tokens = tokens or self.ids_to_tokens(ids)
286
- # tokens = [token for token in tokens if not self._is_special(token)]
287
-
288
- text = ''
289
- for i, token in enumerate(tokens):
290
- if token[:2] == '##':
291
- text += token[2:]
292
- elif len(token) == 1 and _is_chinese_char(ord(token)):
293
- text += token
294
- elif len(token) == 1 and _is_punctuation(token):
295
- text += token
296
- text += ' '
297
- elif i > 0 and _is_chinese_char(ord(text[-1])):
298
- text += token
299
- elif tokens == "</s>":
300
- continue
301
- else:
302
- text += ' '
303
- text += token
304
-
305
- text = re.sub(' +', ' ', text)
306
- text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
307
- punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<['
308
- punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
309
- punctuation_regex = '(%s) ' % punctuation_regex
310
- text = re.sub(punctuation_regex, '\\1', text)
311
- text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
312
-
313
- return text.strip()
314
- # out_string = " ".join(tokens).replace(" ##", "").strip()
315
-
316
- def build_inputs_with_special_tokens(
317
- self,
318
- token_ids_0: List[int],
319
- token_ids_1: Optional[List[int]] = None) -> List[int]:
320
- """
321
- Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
322
- and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
323
- - single sequence: `X </s>`
324
- - pair of sequences: `A B </s>` (not intended use)
325
- BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
326
- separator.
327
- Args:
328
- token_ids_0 (`List[int]`):
329
- List of IDs to which the special tokens will be added.
330
- token_ids_1 (`List[int]`, *optional*):
331
- Optional second list of IDs for sequence pairs.
332
- Returns:
333
- `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
334
- """
335
- if token_ids_1 is None:
336
- return token_ids_0 + [self.eos_token_id]
337
- return token_ids_0 + token_ids_1 + [self.eos_token_id]
338
-
339
- def _special_token_mask(self, seq):
340
- all_special_ids = set(
341
- self.all_special_ids) # call it once instead of inside list comp
342
- # all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
343
-
344
- return [1 if x in all_special_ids else 0 for x in seq]
345
-
346
- def get_special_tokens_mask(
347
- self,
348
- token_ids_0: List[int],
349
- token_ids_1: Optional[List[int]] = None,
350
- already_has_special_tokens: bool = False) -> List[int]:
351
- """
352
- Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
353
- special tokens using the tokenizer `prepare_for_model` method.
354
- Args:
355
- token_ids_0 (`List[int]`):
356
- List of IDs.
357
- token_ids_1 (`List[int]`, *optional*):
358
- Optional second list of IDs for sequence pairs.
359
- already_has_special_tokens (`bool`, *optional*, defaults to `False`):
360
- Whether or not the token list is already formatted with special tokens for the model.
361
- Returns:
362
- `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
363
- """
364
-
365
- if already_has_special_tokens:
366
- return self._special_token_mask(token_ids_0)
367
- elif token_ids_1 is None:
368
- return self._special_token_mask(token_ids_0) + [self.eos_token_id]
369
- else:
370
- return self._special_token_mask(token_ids_0 +
371
- token_ids_1) + [self.eos_token_id]
372
-
373
- def num_special_tokens_to_add(self, pair=False):
374
- """Just EOS"""
375
- return 1
376
-
377
- def save_vocabulary(self,
378
- save_directory: str,
379
- filename_prefix: Optional[str] = None) -> Tuple[str]:
380
- index = 0
381
- if os.path.isdir(save_directory):
382
- vocab_file = os.path.join(
383
- save_directory,
384
- (filename_prefix + "-" if filename_prefix else "") +
385
- VOCAB_FILES_NAMES["vocab_file"])
386
- else:
387
- vocab_file = (filename_prefix +
388
- "-" if filename_prefix else "") + save_directory
389
- with open(vocab_file, "w", encoding="utf-8") as writer:
390
- for token, token_index in sorted(self.vocab.items(),
391
- key=lambda kv: kv[1]):
392
- if index != token_index:
393
- logger.warning(
394
- f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
395
- " Please check that the vocabulary is not corrupted!")
396
- index = token_index
397
- writer.write(token + "\n")
398
- index += 1
399
- return (vocab_file, )
400
-
401
-
402
- class BasicTokenizer(object):
403
- """
404
- Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
405
- Args:
406
- do_lower_case (`bool`, *optional*, defaults to `True`):
407
- Whether or not to lowercase the input when tokenizing.
408
- never_split (`Iterable`, *optional*):
409
- Collection of tokens which will never be split during tokenization. Only has an effect when
410
- `do_basic_tokenize=True`
411
- tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
412
- Whether or not to tokenize Chinese characters.
413
- This should likely be deactivated for Japanese (see this
414
- [issue](https://github.com/huggingface/transformers/issues/328)).
415
- strip_accents: (`bool`, *optional*):
416
- Whether or not to strip all accents. If this option is not specified, then it will be determined by the
417
- value for `lowercase` (as in the original BERT).
418
- """
419
-
420
- def __init__(self,
421
- do_lower_case=True,
422
- never_split=None,
423
- tokenize_chinese_chars=True,
424
- strip_accents=None):
425
- if never_split is None:
426
- never_split = []
427
- self.do_lower_case = do_lower_case
428
- self.never_split = set(never_split)
429
- self.tokenize_chinese_chars = tokenize_chinese_chars
430
- self.strip_accents = strip_accents
431
-
432
- def tokenize(self, text, never_split=None):
433
- """
434
- Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
435
- WordPieceTokenizer.
436
- Args:
437
- never_split (`List[str]`, *optional*)
438
- Kept for backward compatibility purposes. Now implemented directly at the base class level (see
439
- [`PreTrainedTokenizer.tokenize`]) List of token not to split.
440
- """
441
- # union() returns a new set by concatenating the two sets.
442
- never_split = self.never_split.union(
443
- set(never_split)) if never_split else self.never_split
444
- text = self._clean_text(text)
445
-
446
- # This was added on November 1st, 2018 for the multilingual and Chinese
447
- # models. This is also applied to the English models now, but it doesn't
448
- # matter since the English models were not trained on any Chinese data
449
- # and generally don't have any Chinese data in them (there are Chinese
450
- # characters in the vocabulary because Wikipedia does have some Chinese
451
- # words in the English Wikipedia.).
452
- if self.tokenize_chinese_chars:
453
- text = self._tokenize_chinese_chars(text)
454
- orig_tokens = whitespace_tokenize(text)
455
- split_tokens = []
456
- for token in orig_tokens:
457
- if token not in never_split:
458
- if self.do_lower_case:
459
- token = token.lower()
460
- if self.strip_accents is not False:
461
- token = self._run_strip_accents(token)
462
- elif self.strip_accents:
463
- token = self._run_strip_accents(token)
464
- split_tokens.extend(self._run_split_on_punc(token, never_split))
465
-
466
- output_tokens = whitespace_tokenize(" ".join(split_tokens))
467
- return output_tokens
468
-
469
- def _run_strip_accents(self, text):
470
- """Strips accents from a piece of text."""
471
- text = unicodedata.normalize("NFD", text)
472
- output = []
473
- for char in text:
474
- cat = unicodedata.category(char)
475
- if cat == "Mn":
476
- continue
477
- output.append(char)
478
- return "".join(output)
479
-
480
- def _run_split_on_punc(self, text, never_split=None):
481
- """Splits punctuation on a piece of text."""
482
- if never_split is not None and text in never_split:
483
- return [text]
484
- chars = list(text)
485
- i = 0
486
- start_new_word = True
487
- output = []
488
- while i < len(chars):
489
- char = chars[i]
490
- if _is_punctuation(char):
491
- output.append([char])
492
- start_new_word = True
493
- else:
494
- if start_new_word:
495
- output.append([])
496
- start_new_word = False
497
- output[-1].append(char)
498
- i += 1
499
-
500
- return ["".join(x) for x in output]
501
-
502
- def _tokenize_chinese_chars(self, text):
503
- """Adds whitespace around any CJK character."""
504
- output = []
505
- for char in text:
506
- cp = ord(char)
507
- if self._is_chinese_char(cp):
508
- output.append(" ")
509
- output.append(char)
510
- output.append(" ")
511
- else:
512
- output.append(char)
513
- return "".join(output)
514
-
515
- def _is_chinese_char(self, cp):
516
- """Checks whether CP is the codepoint of a CJK character."""
517
- # This defines a "chinese character" as anything in the CJK Unicode block:
518
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
519
- #
520
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
521
- # despite its name. The modern Korean Hangul alphabet is a different block,
522
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
523
- # space-separated words, so they are not treated specially and handled
524
- # like the all of the other languages.
525
- if ((cp >= 0x4E00 and cp <= 0x9FFF)
526
- or (cp >= 0x3400 and cp <= 0x4DBF) #
527
- or (cp >= 0x20000 and cp <= 0x2A6DF) #
528
- or (cp >= 0x2A700 and cp <= 0x2B73F) #
529
- or (cp >= 0x2B740 and cp <= 0x2B81F) #
530
- or (cp >= 0x2B820 and cp <= 0x2CEAF) #
531
- or (cp >= 0xF900 and cp <= 0xFAFF)
532
- or (cp >= 0x2F800 and cp <= 0x2FA1F)): #
533
- return True
534
-
535
- return False
536
-
537
- def _clean_text(self, text):
538
- """Performs invalid character removal and whitespace cleanup on text."""
539
- output = []
540
- for char in text:
541
- cp = ord(char)
542
- if cp == 0 or cp == 0xFFFD or _is_control(char):
543
- continue
544
- if _is_whitespace(char):
545
- output.append(" ")
546
- else:
547
- output.append(char)
548
- return "".join(output)
549
-
550
-
551
- class WordpieceTokenizer(object):
552
- """Runs WordPiece tokenization."""
553
-
554
- def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
555
- self.vocab = vocab
556
- self.unk_token = unk_token
557
- self.max_input_chars_per_word = max_input_chars_per_word
558
-
559
- def tokenize(self, text):
560
- """
561
- Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
562
- tokenization using the given vocabulary.
563
- For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
564
- Args:
565
- text: A single token or whitespace separated tokens. This should have
566
- already been passed through *BasicTokenizer*.
567
- Returns:
568
- A list of wordpiece tokens.
569
- """
570
-
571
- output_tokens = []
572
- for token in whitespace_tokenize(text):
573
- chars = list(token)
574
- if len(chars) > self.max_input_chars_per_word:
575
- output_tokens.append(self.unk_token)
576
- continue
577
-
578
- is_bad = False
579
- start = 0
580
- sub_tokens = []
581
- while start < len(chars):
582
- end = len(chars)
583
- cur_substr = None
584
- while start < end:
585
- substr = "".join(chars[start:end])
586
- if start > 0:
587
- substr = "##" + substr
588
- if substr in self.vocab:
589
- cur_substr = substr
590
- break
591
- end -= 1
592
- if cur_substr is None:
593
- is_bad = True
594
- break
595
- sub_tokens.append(cur_substr)
596
- start = end
597
-
598
- if is_bad:
599
- output_tokens.append(self.unk_token)
600
- else:
601
- output_tokens.extend(sub_tokens)
602
- return output_tokens
 
 
embed.py CHANGED
@@ -1,4 +1,5 @@
 """
+Based on transformers python API.
 This script turn list of string into embeddings.
 """
 from transformers import AutoTokenizer, TFAutoModel
lex_rank.py ADDED
@@ -0,0 +1,44 @@
+import numpy, nltk
+nltk.download('punkt')
+
+
+from harvesttext import HarvestText
+from lex_rank_util import degree_centrality_scores
+from sentence_transformers import SentenceTransformer, util
+
+
+class LexRank(object):
+    def __init__(self):
+        self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
+        self.ht = HarvestText()
+
+    def find_central(self, content: str, num=100):
+        if self.contains_chinese(content):
+            sentences = self.ht.cut_sentences(content)
+        else:
+            sentences = nltk.sent_tokenize(content)
+        embeddings = self.model.encode(sentences, convert_to_tensor=True).cpu()
+
+        # Compute the pair-wise cosine similarities
+        cos_scores = util.cos_sim(embeddings, embeddings).numpy()
+
+        # Compute the centrality for each sentence
+        centrality_scores = degree_centrality_scores(cos_scores, threshold=None)
+
+        # We argsort so that the first element is the sentence with the highest score
+        most_central_sentence_indices = numpy.argsort(-centrality_scores)
+
+        # num = 100
+        res = []
+        for index in most_central_sentence_indices:
+            if num < 0:
+                break
+            res.append(sentences[index])
+            num -= len(sentences[index])
+        return res
+
+    def contains_chinese(self, content: str):
+        for _char in content:
+            if '\u4e00' <= _char <= '\u9fa5':
+                return True
+        return False
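
The contains_chinese helper is what the commit message refers to: if the input contains characters in the basic CJK range (U+4E00 to U+9FA5), sentences are cut with HarvestText, otherwise with nltk.sent_tokenize. A quick illustrative check, not part of the commit, assuming the dependencies above are installed:

from lex_rank import LexRank

lex = LexRank()
print(lex.contains_chinese("这是一个中文句子。"))              # True  -> HarvestText cut_sentences
print(lex.contains_chinese("This is an English sentence."))  # False -> nltk.sent_tokenize
print(lex.find_central("春眠不觉晓,处处闻啼鸟。夜来风雨声,花落知多少。", num=20))

Note that the range check stops at U+9FA5, so text made up only of rarer CJK extension characters would fall through to the English tokenizer.
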
LexRank.py → lex_rank_util.py RENAMED
File without changes
luotuo_util.py DELETED
@@ -1,82 +0,0 @@
1
- import torch
2
-
3
-
4
- class DeviceMap:
5
- __top_layer: str
6
- __device_map: dict
7
- __total_layers: int
8
- __layers: int
9
-
10
- def __init__(self, model=None):
11
- if model == "LLaMA":
12
- self.__top_layer = "model"
13
- self.__device_map = {
14
- "model.embed_tokens": 0,
15
- "model.norm": 0,
16
- "lm_head": 0,
17
- }
18
- self.__total_layers = 34
19
- self.__layers = 32
20
-
21
- elif model == "ChatGLM":
22
- self.__top_layer = "transformer"
23
- self.__device_map = {
24
- "transformer.word_embeddings": 0,
25
- "transformer.final_layernorm": 0,
26
- "lm_head": 0,
27
- }
28
- self.__total_layers = 30
29
- self.__layers = 28
30
-
31
- else:
32
- self.__top_layer = ""
33
- self.__device_map = {"": 0}
34
- self.__total_layers = 0
35
- self.__layers = 0
36
-
37
- def get(self):
38
- top_layer = self.__top_layer
39
- total_layers = self.__total_layers
40
- layers = self.__layers
41
- device_map = self.__device_map
42
-
43
- world_size = torch.cuda.device_count()
44
-
45
- free_gpu_mem = []
46
- for i in range(world_size):
47
- torch.cuda.set_device(i)
48
- free_gpu_mem.append(torch.cuda.mem_get_info()[0])
49
-
50
- min_id = min(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
51
- max_id = max(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
52
-
53
- totol_mem = sum(free_gpu_mem)
54
-
55
- world_layers = {
56
- id: int(round(total_layers * (mem / totol_mem)))
57
- for id, mem in enumerate(free_gpu_mem)
58
- }
59
-
60
- diff = total_layers - sum(world_layers.values())
61
- world_layers[max_id if diff > 0 else min_id] += diff
62
-
63
- cnt = total_layers - layers
64
- gpu_id = 0
65
-
66
- for i in range(layers):
67
- if cnt < world_layers[gpu_id]:
68
- cnt += 1
69
- else:
70
- gpu_id += 1
71
- cnt = 1
72
- device_map[f"{top_layer}.layers.{i}"] = gpu_id
73
-
74
- return device_map
75
-
76
- def peft(self):
77
- prefix = "base_model.model"
78
- device_map = self.get()
79
- perf_device_map = {"": 0}
80
- for k, v in device_map.items():
81
- perf_device_map[f"{prefix}.{k}"] = v
82
- return perf_device_map
 
 
requirements.txt CHANGED
@@ -16,4 +16,5 @@ datasets
 gradio
 
 sentence-transformers
-harvesttext
+harvesttext
+nltk
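
The new nltk requirement is pulled in for the English sentence splitter, and lex_rank.py calls nltk.download('punkt') at import time, which needs network access on first run. A guarded variant that only downloads when the tokenizer data is missing is sketched below; this is a suggestion of mine, not something the commit does.

import nltk

try:
    nltk.data.find("tokenizers/punkt")   # reuse a local copy if it already exists
except LookupError:
    nltk.download("punkt", quiet=True)   # otherwise fetch it once
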
tmp/placeholder DELETED
File without changes