from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.tokenization_utils import AddedToken

# Pin the remote tokenizer code to a specific revision so that upstream
# changes to the Salesforce repo cannot silently alter behavior.
_codegen_revision = dict(
    pretrained_model_name_or_path="Salesforce/codegen25-7b-multi",
    revision="d4dc9dd90e8b23d5411e6d970e3a11e88dc5c2bc",
)

# CodeGen25Tokenizer and its tiktoken_tokenizer factory ship as remote code in
# the model repository, so they are pulled in through transformers' dynamic
# module mechanism rather than imported directly.
CodeGen25Tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.CodeGen25Tokenizer", **_codegen_revision)

tiktoken_tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.tiktoken_tokenizer", **_codegen_revision)


class DeciCoderTokenizer(CodeGen25Tokenizer):

    def __init__(
        self,
        pad_token=None,
        eos_token="<|endoftext|>",
        add_eos_token=False,
        add_special_tokens=True,
        **kwargs,
    ):
        # Remember the tiktoken constructor arguments so the encoder can be
        # rebuilt after unpickling (see __setstate__).
        self._tiktoken_kwargs = dict(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
        self.add_eos_token = add_eos_token
        self.encoder = tiktoken_tokenizer(**self._tiktoken_kwargs)
        # Wrap plain-string special tokens in AddedToken so they are matched
        # exactly, without stripping surrounding whitespace.
        pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        super().__init__(
            pad_token=pad_token_added,
            eos_token=eos_token_added,
            add_eos_token=add_eos_token,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )

    def _convert_id_to_token(self, index):
        """Work around a bug in CodeGen25Tokenizer: return None for ids the
        encoder cannot decode instead of raising."""
        try:
            return super()._convert_id_to_token(index)
        except Exception:
            return None

    def __getstate__(self):
        """Drop the tiktoken encoder, which is not picklable, from the state."""
        return {**self.__dict__, "encoder": None}

    def __setstate__(self, state):
        """Rebuild the tiktoken encoder after unpickling."""
        state["encoder"] = tiktoken_tokenizer(**state["_tiktoken_kwargs"])
        self.__dict__ = state
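
    # A minimal sketch of the round trip these two hooks enable (assuming a
    # constructed `tokenizer` instance):
    #
    #   import pickle
    #   restored = pickle.loads(pickle.dumps(tokenizer))
    #   # restored.encoder was rebuilt by __setstate__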

    def save_pretrained(self, *args, **kwargs):
        """
        add_special_tokens is not JSON serializable, which crashes
        save_pretrained(). Temporarily replacing it with a plain bool while the
        config is written keeps tokenizer_config.json serializable and does not
        affect from_pretrained().
        """
        add_special_tokens = self.add_special_tokens
        self.add_special_tokens = True
        super().save_pretrained(*args, **kwargs)
        self.add_special_tokens = add_special_tokens
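

if __name__ == "__main__":
    # Minimal smoke test: a usage sketch, assuming the remote code can be
    # downloaded from the Hugging Face Hub and that the upstream defaults
    # are valid constructor arguments.
    import tempfile

    tokenizer = DeciCoderTokenizer()
    ids = tokenizer.encode("def hello():")
    print(ids, tokenizer.decode(ids))
    with tempfile.TemporaryDirectory() as tmp:
        tokenizer.save_pretrained(tmp)  # exercises the JSON workaround above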