hellopahe committed on
Commit
1870c14
1 Parent(s): 43a6d14
Files changed (4)
  1. app.py +65 -0
  2. article_extractor/tokenizers_pegasus.py +602 -0
  3. requirements.txt +7 -0
  4. util.py +82 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ """
+ This script generates a summary from a paragraph of Chinese text (normally up to 512 tokens).
+ """
+
+ import torch
+ import gradio as gr
+
+ from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
+ from article_extractor.tokenizers_pegasus import PegasusTokenizer
+
+
+ class SummaryExtractor(object):
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = PegasusForConditionalGeneration.from_pretrained('IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese').to(self.device)
+         self.tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")
+         self.text2text_genr = Text2TextGenerationPipeline(self.model, self.tokenizer, device=self.device)
+
+     def extract(self, content: str, min=20, max=30) -> str:
+         # Alternative: call model.generate() directly instead of the pipeline.
+         # inputs = self.tokenizer(content, max_length=512, return_tensors='pt').to(device=self.device)
+         # summary_ids = self.model.generate(inputs['input_ids'])
+         # return self.tokenizer.batch_decode(summary_ids,
+         #                                    skip_special_tokens=True,
+         #                                    clean_up_tokenization_spaces=False)[0]
+         return str(self.text2text_genr(content, do_sample=False, min_length=min, max_length=max, num_return_sequences=3)[0]["generated_text"])
+
+
+ t_randeng = SummaryExtractor()
+
+
+ def randeng_extract(content):
+     return t_randeng.extract(content)
+
+
+ def similarity_check(query: str, doc: str):
+     # Placeholder: similarity scoring is not implemented yet.
+     return "similarity result"
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("从下面的标签选择不同的摘要模型, 在左侧输入原文")  # "Pick a summarization model from the tabs below and enter the source text on the left."
+     # with gr.Tab("CamelBell-Chinese-LoRA"):
+     #     text_input = gr.Textbox()
+     #     text_output = gr.Textbox()
+     #     text_button = gr.Button("生成摘要")
+     with gr.Tab("Randeng-Pegasus-523M"):
+         text_input_1 = gr.Textbox()
+         text_output_1 = gr.Textbox()
+         text_button_1 = gr.Button("生成摘要")  # "Generate summary"
+     # with gr.Tab("Flip Image"):
+     #     with gr.Row():
+     #         image_input = gr.Image()
+     #         image_output = gr.Image()
+     #     image_button = gr.Button("Flip")
+     with gr.Tab("相似度检测"):  # "Similarity check"
+         with gr.Row():
+             text_input_query = gr.Textbox()
+             text_input_doc = gr.Textbox()
+         text_button_similarity = gr.Button("对比相似度")  # "Compare similarity"
+         text_output_similarity = gr.Textbox()
+
+     # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
+     text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
+     text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
+
+ app.launch()
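For reference, the same pipeline can be exercised outside Gradio. A minimal sketch, assuming the repository files are on the import path and the checkpoint can be downloaded (the input string is a placeholder, not part of the commit):

import torch
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
from article_extractor.tokenizers_pegasus import PegasusTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese"
model = PegasusForConditionalGeneration.from_pretrained(model_id).to(device)
tokenizer = PegasusTokenizer.from_pretrained(model_id)
summarize = Text2TextGenerationPipeline(model, tokenizer, device=device)

text = "这里粘贴一段待摘要的中文新闻段落。"  # placeholder input paragraph
result = summarize(text, do_sample=False, min_length=20, max_length=60)
print(result[0]["generated_text"])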
article_extractor/tokenizers_pegasus.py ADDED
@@ -0,0 +1,602 @@
+ import sys
+ sys.path.append('../')
+ from article_extractor.data_utils import (
+     _is_control,
+     _is_punctuation,
+     _is_whitespace,
+     _is_chinese_char)
+ from transformers import PreTrainedTokenizer
+ from transformers import logging
+ from typing import List, Optional, Tuple, Union
+ import collections
+ import os
+ import unicodedata
+ import re
+ import jieba
+
+ # Summary-extraction tokenization logic.
+
+ # sys.path.append("../../../../")
+
+ jieba.dt.tmp_dir = os.path.expanduser(
+     "../tmp/")
+ # jieba.enable_parallel(8)
+ jieba.initialize()
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+ def load_vocab(vocab_file):
+     """Loads a vocabulary file into a dictionary."""
+     vocab = collections.OrderedDict()
+     with open(vocab_file, "r", encoding="utf-8") as reader:
+         tokens = reader.readlines()
+     for index, token in enumerate(tokens):
+         token = token.rstrip("\n")
+         vocab[token] = index
+     return vocab
+
+
+ def whitespace_tokenize(text):
+     """Runs basic whitespace cleaning and splitting on a piece of text."""
+     text = text.strip()
+     if not text:
+         return []
+     tokens = text.split()
+     return tokens
+
+
+ class PegasusTokenizer(PreTrainedTokenizer):
+     # copy from BertTokenizer
+     r"""
+     Construct a Pegasus tokenizer. Based on WordPiece.
+     This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+     this superclass for more information regarding those methods.
+     Args:
+         vocab_file (`str`):
+             File containing the vocabulary.
+         do_lower_case (`bool`, *optional*, defaults to `True`):
+             Whether or not to lowercase the input when tokenizing.
+         do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+             Whether or not to do basic tokenization before WordPiece.
+         never_split (`Iterable`, *optional*):
+             Collection of tokens which will never be split during tokenization. Only has an effect when
+             `do_basic_tokenize=True`
+         unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+             sequence classification or for a text and a question for question answering. It is also used as the last
+             token of a sequence built with special tokens.
+         pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+             The token used for padding, for example when batching sequences of different lengths.
+         cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+             The classifier token which is used when doing sequence classification (classification of the whole sequence
+             instead of per-token classification). It is the first token of the sequence when built with special tokens.
+         mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+             The token used for masking values. This is the token used when training this model with masked language
+             modeling. This is the token which the model will try to predict.
+         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+             Whether or not to tokenize Chinese characters.
+             This should likely be deactivated for Japanese (see this
+             [issue](https://github.com/huggingface/transformers/issues/328)).
+         strip_accents (`bool`, *optional*):
+             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+             value for `lowercase` (as in the original BERT).
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     # pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+     # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+     def __init__(self,
+                  vocab_file,
+                  do_lower_case=True,
+                  do_basic_tokenize=True,
+                  never_split=None,
+                  pad_token="<pad>",
+                  eos_token="</s>",
+                  unk_token="<unk>",
+                  mask_token="<mask_2>",
+                  mask_token_sent="<mask_1>",
+                  additional_special_tokens=None,
+                  sep_token="[SEP]",
+                  cls_token="[CLS]",
+                  tokenize_chinese_chars=True,
+                  strip_accents=None,
+                  offset=100,
+                  pre_tokenizer=lambda x: jieba.cut(x, HMM=False),
+                  **kwargs):
+         self.offset = offset
+
+         if additional_special_tokens is not None:
+             if not isinstance(additional_special_tokens, list):
+                 raise TypeError(
+                     f"additional_special_tokens should be of type list, \
+                     but is {type(additional_special_tokens)}"
+                 )
+
+             additional_special_tokens_extended = (
+                 ([mask_token_sent] + additional_special_tokens)
+                 if mask_token_sent not in additional_special_tokens
+                 and mask_token_sent is not None else additional_special_tokens)
+
+             # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
+             additional_special_tokens_extended += [
+                 f"<unk_{i}>" for i in range(
+                     len(additional_special_tokens_extended), self.offset - 1)
+             ]
+
+             if len(set(additional_special_tokens_extended)) != len(
+                     additional_special_tokens_extended):
+                 raise ValueError(
+                     f"Please make sure that the provided additional_special_tokens \
+                     do not contain an incorrectly shifted list of <unk_x> tokens. \
+                     Found {additional_special_tokens_extended}."
+                 )
+             additional_special_tokens = additional_special_tokens_extended
+         else:
+             additional_special_tokens = [
+                 mask_token_sent
+             ] if mask_token_sent is not None else []
+             # additional_special_tokens += [f"<unk_{i}>" for i in range(3, self.offset)]
+
+         # print("additional_special_tokens: ", additional_special_tokens)
+
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 f"Can't find a vocabulary file at path '{vocab_file}'. \
+                 To load the vocabulary from a Google pretrained "
+                 "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+             )
+
+         super().__init__(
+             do_lower_case=do_lower_case,
+             do_basic_tokenize=do_basic_tokenize,
+             never_split=never_split,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             tokenize_chinese_chars=tokenize_chinese_chars,
+             additional_special_tokens=additional_special_tokens,
+             strip_accents=strip_accents,
+             **kwargs,
+         )
+
+         self.pre_tokenizer = pre_tokenizer
+         self.mask_token_sent = mask_token_sent
+         self.vocab = load_vocab(vocab_file)
+
+         self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
+         # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
+         self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
+         self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
+
+         if self.mask_token_sent is not None:
+             self.vocab[self.mask_token] = self.vocab.pop("[unused3]")
+             self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]")
+
+         self.ids_to_tokens = collections.OrderedDict([
+             (ids, tok) for tok, ids in self.vocab.items()
+         ])
+         self.do_basic_tokenize = do_basic_tokenize
+         if do_basic_tokenize:
+             self.basic_tokenizer = BasicTokenizer(
+                 do_lower_case=do_lower_case,
+                 never_split=never_split,
+                 tokenize_chinese_chars=tokenize_chinese_chars,
+                 strip_accents=strip_accents,
+             )
+         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
+                                                       unk_token=self.unk_token)
+
+     @property
+     def do_lower_case(self):
+         return self.basic_tokenizer.do_lower_case
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def get_vocab(self):
+         return dict(self.vocab, **self.added_tokens_encoder)
+
+     def _tokenize(self, text):
+         split_tokens = []
+         # print("pegasus_tokenizer: ", text)
+         for text in self.pre_tokenizer(text):
+             if text in self.vocab:
+                 split_tokens.append(text)
+             else:
+                 if self.do_basic_tokenize:
+                     for token in self.basic_tokenizer.tokenize(
+                             text, never_split=self.all_special_tokens):
+
+                         # If the token is part of the never_split set
+                         if token in self.basic_tokenizer.never_split:
+                             split_tokens.append(token)
+                         else:
+                             split_tokens += self.wordpiece_tokenizer.tokenize(
+                                 token)
+                 else:
+                     split_tokens = self.wordpiece_tokenizer.tokenize(text)
+         return split_tokens
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) to an id using the vocab."""
+         return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (str) using the vocab."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     @staticmethod
+     def _cjk_punctuation():
+         return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\
+             \uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\
+             \uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\
+             \u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\
+             \u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
+
+     def convert_ids_to_tokens(
+             self,
+             ids: Union[int, List[int]],
+             skip_special_tokens: bool = False) -> Union[str, List[str]]:
+         """
+         Converts a single index or a sequence of indices to a token or a sequence of tokens, using the vocabulary and
+         added tokens.
+         Args:
+             ids (`int` or `List[int]`):
+                 The token id (or token ids) to convert to tokens.
+             skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not to remove special tokens in the decoding.
+         Returns:
+             `str` or `List[str]`: The decoded token(s).
+         """
+         if isinstance(ids, int):
+             if ids in self.added_tokens_decoder:
+                 return self.added_tokens_decoder[ids]
+             else:
+                 return self._convert_id_to_token(ids)
+         tokens = []
+         for index in ids:
+             index = int(index)
+             if skip_special_tokens and index in self.all_special_ids and index != 2:
+                 continue
+             if index in self.added_tokens_decoder:
+                 tokens.append(self.added_tokens_decoder[index])
+             else:
+                 tokens.append(self._convert_id_to_token(index))
+         return tokens
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) to a single string."""
+         # for token in
+         # tokens = tokens or self.ids_to_tokens(ids)
+         # tokens = [token for token in tokens if not self._is_special(token)]
+
+         text = ''
+         for i, token in enumerate(tokens):
+             if token[:2] == '##':
+                 text += token[2:]
+             elif len(token) == 1 and _is_chinese_char(ord(token)):
+                 text += token
+             elif len(token) == 1 and _is_punctuation(token):
+                 text += token
+                 text += ' '
+             elif i > 0 and _is_chinese_char(ord(text[-1])):
+                 text += token
+             elif token == "</s>":
+                 continue
+             else:
+                 text += ' '
+                 text += token
+
+         text = re.sub(' +', ' ', text)
+         text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
+         punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<['
+         punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
+         punctuation_regex = '(%s) ' % punctuation_regex
+         text = re.sub(punctuation_regex, '\\1', text)
+         text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
+
+         return text.strip()
+         # out_string = " ".join(tokens).replace(" ##", "").strip()
+
+     def build_inputs_with_special_tokens(
+             self,
+             token_ids_0: List[int],
+             token_ids_1: Optional[List[int]] = None) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+         and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+         - single sequence: `X </s>`
+         - pair of sequences: `A B </s>` (not intended use)
+         BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+         separator.
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs to which the special tokens will be added.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+         """
+         if token_ids_1 is None:
+             return token_ids_0 + [self.eos_token_id]
+         return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+     def _special_token_mask(self, seq):
+         all_special_ids = set(
+             self.all_special_ids)  # call it once instead of inside list comp
+         # all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
+
+         return [1 if x in all_special_ids else 0 for x in seq]
+
+     def get_special_tokens_mask(
+             self,
+             token_ids_0: List[int],
+             token_ids_1: Optional[List[int]] = None,
+             already_has_special_tokens: bool = False) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+
+         if already_has_special_tokens:
+             return self._special_token_mask(token_ids_0)
+         elif token_ids_1 is None:
+             return self._special_token_mask(token_ids_0) + [1]  # the appended EOS is a special token
+         else:
+             return self._special_token_mask(token_ids_0 +
+                                             token_ids_1) + [1]
+
+     def num_special_tokens_to_add(self, pair=False):
+         """Just EOS"""
+         return 1
+
+     def save_vocabulary(self,
+                         save_directory: str,
+                         filename_prefix: Optional[str] = None) -> Tuple[str]:
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "") +
+                 VOCAB_FILES_NAMES["vocab_file"])
+         else:
+             vocab_file = (filename_prefix +
+                           "-" if filename_prefix else "") + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.vocab.items(),
+                                              key=lambda kv: kv[1]):
+                 if index != token_index:
+                     logger.warning(
+                         f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!")
+                     index = token_index
+                 writer.write(token + "\n")
+                 index += 1
+         return (vocab_file, )
+
+
+ class BasicTokenizer(object):
+     """
+     Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+     Args:
+         do_lower_case (`bool`, *optional*, defaults to `True`):
+             Whether or not to lowercase the input when tokenizing.
+         never_split (`Iterable`, *optional*):
+             Collection of tokens which will never be split during tokenization. Only has an effect when
+             `do_basic_tokenize=True`
+         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+             Whether or not to tokenize Chinese characters.
+             This should likely be deactivated for Japanese (see this
+             [issue](https://github.com/huggingface/transformers/issues/328)).
+         strip_accents: (`bool`, *optional*):
+             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+             value for `lowercase` (as in the original BERT).
+     """
+
+     def __init__(self,
+                  do_lower_case=True,
+                  never_split=None,
+                  tokenize_chinese_chars=True,
+                  strip_accents=None):
+         if never_split is None:
+             never_split = []
+         self.do_lower_case = do_lower_case
+         self.never_split = set(never_split)
+         self.tokenize_chinese_chars = tokenize_chinese_chars
+         self.strip_accents = strip_accents
+
+     def tokenize(self, text, never_split=None):
+         """
+         Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
+         WordPieceTokenizer.
+         Args:
+             never_split (`List[str]`, *optional*)
+                 Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                 [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
+         """
+         # union() returns a new set by concatenating the two sets.
+         never_split = self.never_split.union(
+             set(never_split)) if never_split else self.never_split
+         text = self._clean_text(text)
+
+         # This was added on November 1st, 2018 for the multilingual and Chinese
+         # models. This is also applied to the English models now, but it doesn't
+         # matter since the English models were not trained on any Chinese data
+         # and generally don't have any Chinese data in them (there are Chinese
+         # characters in the vocabulary because Wikipedia does have some Chinese
+         # words in the English Wikipedia.).
+         if self.tokenize_chinese_chars:
+             text = self._tokenize_chinese_chars(text)
+         orig_tokens = whitespace_tokenize(text)
+         split_tokens = []
+         for token in orig_tokens:
+             if token not in never_split:
+                 if self.do_lower_case:
+                     token = token.lower()
+                     if self.strip_accents is not False:
+                         token = self._run_strip_accents(token)
+                 elif self.strip_accents:
+                     token = self._run_strip_accents(token)
+             split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+         output_tokens = whitespace_tokenize(" ".join(split_tokens))
+         return output_tokens
+
+     def _run_strip_accents(self, text):
+         """Strips accents from a piece of text."""
+         text = unicodedata.normalize("NFD", text)
+         output = []
+         for char in text:
+             cat = unicodedata.category(char)
+             if cat == "Mn":
+                 continue
+             output.append(char)
+         return "".join(output)
+
+     def _run_split_on_punc(self, text, never_split=None):
+         """Splits punctuation on a piece of text."""
+         if never_split is not None and text in never_split:
+             return [text]
+         chars = list(text)
+         i = 0
+         start_new_word = True
+         output = []
+         while i < len(chars):
+             char = chars[i]
+             if _is_punctuation(char):
+                 output.append([char])
+                 start_new_word = True
+             else:
+                 if start_new_word:
+                     output.append([])
+                 start_new_word = False
+                 output[-1].append(char)
+             i += 1
+
+         return ["".join(x) for x in output]
+
+     def _tokenize_chinese_chars(self, text):
+         """Adds whitespace around any CJK character."""
+         output = []
+         for char in text:
+             cp = ord(char)
+             if self._is_chinese_char(cp):
+                 output.append(" ")
+                 output.append(char)
+                 output.append(" ")
+             else:
+                 output.append(char)
+         return "".join(output)
+
+     def _is_chinese_char(self, cp):
+         """Checks whether CP is the codepoint of a CJK character."""
+         # This defines a "chinese character" as anything in the CJK Unicode block:
+         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+         #
+         # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+         # despite its name. The modern Korean Hangul alphabet is a different block,
+         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+         # space-separated words, so they are not treated specially and handled
+         # like all of the other languages.
+         if ((cp >= 0x4E00 and cp <= 0x9FFF)
+                 or (cp >= 0x3400 and cp <= 0x4DBF)  #
+                 or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+                 or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+                 or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+                 or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+                 or (cp >= 0xF900 and cp <= 0xFAFF)
+                 or (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+             return True
+
+         return False
+
+     def _clean_text(self, text):
+         """Performs invalid character removal and whitespace cleanup on text."""
+         output = []
+         for char in text:
+             cp = ord(char)
+             if cp == 0 or cp == 0xFFFD or _is_control(char):
+                 continue
+             if _is_whitespace(char):
+                 output.append(" ")
+             else:
+                 output.append(char)
+         return "".join(output)
+
+
+ class WordpieceTokenizer(object):
+     """Runs WordPiece tokenization."""
+
+     def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+         self.vocab = vocab
+         self.unk_token = unk_token
+         self.max_input_chars_per_word = max_input_chars_per_word
+
+     def tokenize(self, text):
+         """
+         Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+         tokenization using the given vocabulary.
+         For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+         Args:
+             text: A single token or whitespace separated tokens. This should have
+                 already been passed through *BasicTokenizer*.
+         Returns:
+             A list of wordpiece tokens.
+         """
+
+         output_tokens = []
+         for token in whitespace_tokenize(text):
+             chars = list(token)
+             if len(chars) > self.max_input_chars_per_word:
+                 output_tokens.append(self.unk_token)
+                 continue
+
+             is_bad = False
+             start = 0
+             sub_tokens = []
+             while start < len(chars):
+                 end = len(chars)
+                 cur_substr = None
+                 while start < end:
+                     substr = "".join(chars[start:end])
+                     if start > 0:
+                         substr = "##" + substr
+                     if substr in self.vocab:
+                         cur_substr = substr
+                         break
+                     end -= 1
+                 if cur_substr is None:
+                     is_bad = True
+                     break
+                 sub_tokens.append(cur_substr)
+                 start = end
+
+             if is_bad:
+                 output_tokens.append(self.unk_token)
+             else:
+                 output_tokens.extend(sub_tokens)
+         return output_tokens
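For orientation, a minimal sketch of the tokenizer on its own, assuming the same Randeng checkpoint (which ships the vocab.txt this class expects); the sample sentence is a placeholder:

from article_extractor.tokenizers_pegasus import PegasusTokenizer

tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese")

sample = "今天天气很好。"                     # placeholder Chinese input
tokens = tokenizer.tokenize(sample)          # jieba pre-segmentation first, then WordPiece for out-of-vocab pieces
ids = tokenizer(sample)["input_ids"]         # build_inputs_with_special_tokens appends the </s> id
print(tokens)
print(tokenizer.decode(ids, skip_special_tokens=True))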
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ tensorflow
+
+ transformers
+ peft
+ pandas
+ sentence-transformers
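requirements.txt already pulls in sentence-transformers even though similarity_check in app.py is still a stub; a hedged sketch of one way the 相似度检测 (similarity check) tab could be wired up, with a placeholder model name:

from sentence_transformers import SentenceTransformer, util

st_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")  # placeholder model choice

def similarity_check(query: str, doc: str) -> str:
    # Embed both texts and report their cosine similarity as the tab's output string.
    embeddings = st_model.encode([query, doc], convert_to_tensor=True)
    score = util.cos_sim(embeddings[0], embeddings[1]).item()
    return f"cosine similarity: {score:.4f}"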
util.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+
+
+ class DeviceMap:
+     __top_layer: str
+     __device_map: dict
+     __total_layers: int
+     __layers: int
+
+     def __init__(self, model=None):
+         if model == "LLaMA":
+             self.__top_layer = "model"
+             self.__device_map = {
+                 "model.embed_tokens": 0,
+                 "model.norm": 0,
+                 "lm_head": 0,
+             }
+             self.__total_layers = 34
+             self.__layers = 32
+
+         elif model == "ChatGLM":
+             self.__top_layer = "transformer"
+             self.__device_map = {
+                 "transformer.word_embeddings": 0,
+                 "transformer.final_layernorm": 0,
+                 "lm_head": 0,
+             }
+             self.__total_layers = 30
+             self.__layers = 28
+
+         else:
+             self.__top_layer = ""
+             self.__device_map = {"": 0}
+             self.__total_layers = 0
+             self.__layers = 0
+
+     def get(self):
+         top_layer = self.__top_layer
+         total_layers = self.__total_layers
+         layers = self.__layers
+         device_map = self.__device_map
+
+         world_size = torch.cuda.device_count()
+
+         free_gpu_mem = []
+         for i in range(world_size):
+             torch.cuda.set_device(i)
+             free_gpu_mem.append(torch.cuda.mem_get_info()[0])
+
+         min_id = min(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
+         max_id = max(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
+
+         total_mem = sum(free_gpu_mem)
+
+         # Give each GPU a share of the layers proportional to its free memory.
+         world_layers = {
+             id: int(round(total_layers * (mem / total_mem)))
+             for id, mem in enumerate(free_gpu_mem)
+         }
+
+         diff = total_layers - sum(world_layers.values())
+         world_layers[max_id if diff > 0 else min_id] += diff
+
+         cnt = total_layers - layers
+         gpu_id = 0
+
+         for i in range(layers):
+             if cnt < world_layers[gpu_id]:
+                 cnt += 1
+             else:
+                 gpu_id += 1
+                 cnt = 1
+             device_map[f"{top_layer}.layers.{i}"] = gpu_id
+
+         return device_map
+
+     def peft(self):
+         prefix = "base_model.model"
+         device_map = self.get()
+         peft_device_map = {"": 0}
+         for k, v in device_map.items():
+             peft_device_map[f"{prefix}.{k}"] = v
+         return peft_device_map
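util.py is not referenced by app.py in this commit; presumably the returned map is meant to be passed as device_map= when loading a large multi-GPU model. A hypothetical sketch (the checkpoint path and the accelerate-backed device_map argument are assumptions, not part of the commit):

from transformers import AutoModelForCausalLM
from util import DeviceMap

# Spread a 32-layer LLaMA-style model over the visible GPUs in proportion to free memory.
device_map = DeviceMap("LLaMA").get()        # e.g. {"model.embed_tokens": 0, ..., "model.layers.31": <gpu id>}
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",              # placeholder path
    device_map=device_map,
)

# The same layout with the "base_model.model." prefix for a PEFT-wrapped model.
peft_device_map = DeviceMap("LLaMA").peft()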