|
from transformers import PreTrainedTokenizerFast |
|
import numpy |
|
import torch |
|
|
|
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that post-processes ModernBERT-style encodings for decoder use.

    Every batch produced by the parent ``PreTrainedTokenizerFast`` is adjusted so that:

    * ``token_type_ids`` is never returned, and
    * when special tokens were added, the last non-padding token of each row is
      removed. NOTE(review): this presumes the tokenizer's post-processor appends
      exactly one trailing special token (e.g. ``[SEP]``) -- confirm against the
      ``tokenizer.json`` this class is used with.

    Each row loses exactly one position, so padded widths shrink by one
    (e.g. ``padding="max_length"`` yields ``max_length - 1`` columns).

    Known limitation: the backend ``Encoding`` objects exposed through
    ``BatchEncoding.encodings`` are not modified, so helpers built on them
    (``word_ids()``, ``token_to_chars()``, ...) still see the removed token.
    """

    # Output keys holding exactly one entry per token position; all of them must
    # lose the same position so they stay aligned with ``input_ids``.
    _PER_TOKEN_KEYS = (
        "input_ids",
        "attention_mask",
        "special_tokens_mask",
        "offset_mapping",
    )

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch, then drop ``token_type_ids`` and each row's final token.

        All arguments are forwarded to ``PreTrainedTokenizerFast._batch_encode_plus``,
        except ``return_tensors``: it is withheld from the parent and applied at the
        end. The per-row edits therefore always operate on plain Python lists,
        whatever framework was requested. This assumes ``return_tensors``,
        ``add_special_tokens`` and ``padding_side`` arrive as keyword arguments.

        Returns:
            The parent's ``BatchEncoding`` without ``token_type_ids``. If
            ``add_special_tokens`` is true (the default), every per-token field also
            has its final non-padding token removed and ``length`` is decremented.
            The result is converted to ``return_tensors`` when one was requested.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # The key is absent when return_token_type_ids=False; a bare `del`
        # would raise KeyError in that case.
        if "token_type_ids" in outputs:
            del outputs["token_type_ids"]

        # With add_special_tokens=False there is no trailing special token, and
        # stripping would delete a real content token instead.
        if kwargs.get("add_special_tokens", True):
            padding_side = kwargs.get("padding_side") or self.padding_side
            self._strip_final_token(outputs, padding_side)

        return outputs.convert_to_tensors(return_tensors)

    def _strip_final_token(self, outputs, padding_side):
        """Remove the last non-padding token of every row of ``outputs`` in place."""
        widths = [len(row) for row in outputs["input_ids"]]
        valid = self._valid_lengths(outputs) or widths
        # Right padding puts pads *after* the final token, so dropping the last
        # column would remove a pad from every shorter row and keep its final
        # token. Locate the token from the row's valid length instead. With left
        # padding (or no padding) the final token is the row's last element.
        drop_at = [
            n - 1 if padding_side == "right" and n > 0 else width - 1
            for n, width in zip(valid, widths)
        ]

        for key in self._PER_TOKEN_KEYS:
            if key in outputs:
                outputs[key] = [
                    row[:i] + row[i + 1 :] if i >= 0 else row
                    for row, i in zip(outputs[key], drop_at)
                ]

        # `length` (return_length=True) counts token positions per row, so it
        # shrinks by the one removed position.
        if "length" in outputs:
            outputs["length"] = [
                n - 1 if i >= 0 else n for n, i in zip(outputs["length"], drop_at)
            ]

    @staticmethod
    def _valid_lengths(outputs):
        """Return the number of non-padding tokens per row, or None if unknown."""
        if "attention_mask" in outputs:
            return [sum(row) for row in outputs["attention_mask"]]
        # The attention mask was not requested; fall back to the backend encodings.
        encodings = outputs.encodings
        if encodings:
            return [sum(encoding.attention_mask) for encoding in encodings]
        return None
|
|
|
|
|
from transformers import AutoTokenizer |
|
# Register this tokenizer in transformers' AutoTokenizer mapping; this runs once,
# as a side effect of importing the module.
# NOTE(review): per the transformers docs, the first argument of
# AutoTokenizer.register is the model's *configuration* class. Here the tokenizer
# class itself is passed as the key, so config-driven lookups presumably never
# match this entry -- confirm which config class this tokenizer should be tied to.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)