test-flex-gpt / tokenizer.py
oweller2
done:
b7a2cf0
raw
history blame
869 Bytes
from transformers import PreTrainedTokenizerFast
import numpy
import torch
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that drops the final token from every encoded sequence.

    Intended for decoder-style use of a BERT-family tokenizer: the trailing
    special token appended by the underlying fast tokenizer is sliced off,
    and ``token_type_ids`` are removed entirely.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch via the parent class, then post-process.

        Removes ``token_type_ids`` (if present) and trims the last position
        from ``input_ids`` and ``attention_mask``, handling tensor, ndarray,
        and plain-list return types. Returns the modified batch encoding.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # pop() instead of del: token_type_ids is absent when the caller
        # requests return_token_type_ids=False, and bare del would raise
        # KeyError in that case.
        outputs.pop("token_type_ids", None)
        for key in ("input_ids", "attention_mask"):
            if key not in outputs:
                continue
            value = outputs[key]
            if isinstance(value, (torch.Tensor, numpy.ndarray)):
                # Tensor/ndarray share the same ellipsis slicing; drop the
                # last entry along the sequence (trailing) axis.
                outputs[key] = value[..., :-1]
            elif isinstance(value, list):
                outputs[key] = [sequence[:-1] for sequence in value]
        return outputs
# Register the class with the Auto* factory so
# AutoTokenizer.from_pretrained(...) can resolve it.
from transformers import AutoTokenizer
# NOTE(review): AutoTokenizer.register documents its first positional
# parameter as a *config* class, but the tokenizer class itself is passed
# here as the registry key — presumably deliberate for remote-code loading;
# confirm against how the checkpoint's tokenizer_config references this class.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)