Transformers
PyTorch
code
custom_code
Inference Endpoints
codesage commited on
Commit
d8ef66d
1 Parent(s): d4c8b2d

enable add_eos_token

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. tokenization_codesage.py +277 -0
config.json CHANGED
@@ -5,6 +5,7 @@
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "config_codesage.CodeSageConfig",
 
8
  "AutoModel": "modeling_codesage.CodeSageModel",
9
  "AutoModelForMaskedLM": "modeling_codesage.CodeSageForMaskedLM",
10
  "AutoModelForSequenceClassification": "modeling_codesage.CodeSageForSequenceClassification"
 
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "config_codesage.CodeSageConfig",
8
+ "AutoTokenizer": "tokenization_codesage.CodeSageTokenizer",
9
  "AutoModel": "modeling_codesage.CodeSageModel",
10
  "AutoModelForMaskedLM": "modeling_codesage.CodeSageForMaskedLM",
11
  "AutoModelForSequenceClassification": "modeling_codesage.CodeSageForSequenceClassification"
tokenization_codesage.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import List, Optional, Tuple
5
+
6
+ import regex as re
7
+
8
+ from transformers import AddedToken, PreTrainedTokenizer
9
+ import logging
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ VOCAB_FILES_NAMES = {
15
+ "vocab_file": "vocab.json",
16
+ "merges_file": "merges.txt",
17
+ }
18
+
19
+ # Taken from
20
+ # https://github.com/huggingface/transformers/blob/8aca43bdb3cb9a5020f6d57589d85679dc873b1c/src/transformers/models/gpt2/tokenization_gpt2.py#L62-L84
21
+ @lru_cache()
22
+ def bytes_to_unicode():
23
+ """
24
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
25
+ characters the bpe code barfs on.
26
+
27
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
28
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
29
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
30
+ tables between utf-8 bytes and unicode strings.
31
+ """
32
+ bs = (
33
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
34
+ )
35
+ cs = bs[:]
36
+ n = 0
37
+ for b in range(2**8):
38
+ if b not in bs:
39
+ bs.append(b)
40
+ cs.append(2**8 + n)
41
+ n += 1
42
+ cs = [chr(n) for n in cs]
43
+ return dict(zip(bs, cs))
44
+
45
+
46
+ def get_pairs(word):
47
+ """
48
+ Return set of symbol pairs in a word.
49
+
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
+ pairs.add((prev_char, char))
56
+ prev_char = char
57
+ return pairs
58
+
59
+
60
+ class CodeSageTokenizer(PreTrainedTokenizer):
61
+ """A thin wrapper of the starcoder tokenizer.
62
+ See HuggingFace for further documentation on general tokenizer methods.
63
+ """
64
+
65
+ vocab_files_names = VOCAB_FILES_NAMES
66
+ model_input_names = ["input_ids", "attention_mask"]
67
+
68
+ def __init__(
69
+ self,
70
+ vocab_file,
71
+ merges_file,
72
+ errors="replace",
73
+ unk_token="<|endoftext|>",
74
+ bos_token="<|endoftext|>",
75
+ eos_token="<|endoftext|>",
76
+ pad_token=None,
77
+ add_prefix_space=False,
78
+ add_bos_token=False,
79
+ add_eos_token=False,
80
+ **kwargs,
81
+ ):
82
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
83
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
84
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
85
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
86
+
87
+ self.add_bos_token = add_bos_token
88
+ self.add_eos_token = add_eos_token
89
+
90
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
91
+ self.encoder = json.load(vocab_handle)
92
+ self.decoder = {v: k for k, v in self.encoder.items()}
93
+ self.errors = errors # how to handle errors in decoding
94
+ self.byte_encoder = bytes_to_unicode()
95
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
96
+ with open(merges_file, encoding="utf-8") as merges_handle:
97
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
98
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
99
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
100
+ self.cache = {}
101
+ self.add_prefix_space = add_prefix_space
102
+
103
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
104
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
105
+
106
+ super().__init__(
107
+ errors=errors,
108
+ unk_token=unk_token,
109
+ bos_token=bos_token,
110
+ eos_token=eos_token,
111
+ pad_token=pad_token,
112
+ add_prefix_space=add_prefix_space,
113
+ add_bos_token=add_bos_token,
114
+ add_eos_token=add_eos_token,
115
+ **kwargs,
116
+ )
117
+
118
+ @property
119
+ def vocab_size(self):
120
+ return len(self.encoder)
121
+
122
+ def get_vocab(self):
123
+ return dict(self.encoder, **self.added_tokens_encoder)
124
+
125
+ def bpe(self, token):
126
+ if token in self.cache:
127
+ return self.cache[token]
128
+ word = tuple(token)
129
+ pairs = get_pairs(word)
130
+
131
+ if not pairs:
132
+ return token
133
+
134
+ while True:
135
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
136
+ if bigram not in self.bpe_ranks:
137
+ break
138
+ first, second = bigram
139
+ new_word = []
140
+ i = 0
141
+ while i < len(word):
142
+ try:
143
+ j = word.index(first, i)
144
+ except ValueError:
145
+ new_word.extend(word[i:])
146
+ break
147
+ else:
148
+ new_word.extend(word[i:j])
149
+ i = j
150
+
151
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
152
+ new_word.append(first + second)
153
+ i += 2
154
+ else:
155
+ new_word.append(word[i])
156
+ i += 1
157
+ new_word = tuple(new_word)
158
+ word = new_word
159
+ if len(word) == 1:
160
+ break
161
+ else:
162
+ pairs = get_pairs(word)
163
+ word = " ".join(word)
164
+ self.cache[token] = word
165
+ return word
166
+
167
+ def build_inputs_with_special_tokens(
168
+ self,
169
+ token_ids_0: List[int],
170
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
171
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
172
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
173
+
174
+ output = bos_token_id + token_ids_0 + eos_token_id
175
+
176
+ if token_ids_1 is not None:
177
+ output = output + bos_token_id + token_ids_1 + eos_token_id
178
+
179
+ return output
180
+
181
+ def get_special_tokens_mask(
182
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
183
+ ) -> List[int]:
184
+ """
185
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
186
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
187
+
188
+ Args:
189
+ token_ids_0 (`List[int]`):
190
+ List of IDs.
191
+ token_ids_1 (`List[int]`, *optional*):
192
+ Optional second list of IDs for sequence pairs.
193
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
194
+ Whether or not the token list is already formatted with special tokens for the model.
195
+
196
+ Returns:
197
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
198
+ """
199
+ if already_has_special_tokens:
200
+ return super().get_special_tokens_mask(
201
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
202
+ )
203
+
204
+ if not self.add_bos_token:
205
+ return super().get_special_tokens_mask(
206
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
207
+ )
208
+
209
+ if token_ids_1 is None:
210
+ return [1] + ([0] * len(token_ids_0))
211
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
212
+
213
+ def _tokenize(self, text):
214
+ """Tokenize a string."""
215
+ bpe_tokens = []
216
+ for token in re.findall(self.pat, text):
217
+ token = "".join(
218
+ self.byte_encoder[b] for b in token.encode("utf-8")
219
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
220
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
221
+ return bpe_tokens
222
+
223
+ def _convert_token_to_id(self, token):
224
+ """Converts a token (str) in an id using the vocab."""
225
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
226
+
227
+ def _convert_id_to_token(self, index):
228
+ """Converts an index (integer) in a token (str) using the vocab."""
229
+ return self.decoder.get(index)
230
+
231
+ def convert_tokens_to_string(self, tokens):
232
+ """Converts a sequence of tokens (string) in a single string."""
233
+ text = "".join(tokens)
234
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
235
+ return text
236
+
237
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
238
+ if not os.path.isdir(save_directory):
239
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
240
+ return
241
+ vocab_file = os.path.join(
242
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
243
+ )
244
+ merge_file = os.path.join(
245
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
246
+ )
247
+
248
+ with open(vocab_file, "w", encoding="utf-8") as f:
249
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
250
+
251
+ index = 0
252
+ with open(merge_file, "w", encoding="utf-8") as writer:
253
+ writer.write("#version: 0.2\n")
254
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
255
+ if index != token_index:
256
+ logger.warning(
257
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
258
+ " Please check that the tokenizer is not corrupted!"
259
+ )
260
+ index = token_index
261
+ writer.write(" ".join(bpe_tokens) + "\n")
262
+ index += 1
263
+
264
+ return vocab_file, merge_file
265
+
266
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
267
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
268
+ if is_split_into_words or add_prefix_space:
269
+ text = " " + text
270
+ return (text, kwargs)
271
+
272
+ @property
273
+ def default_chat_template(self):
274
+ """
275
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
276
+ """
277
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"