from transformers import AutoTokenizer, PreTrainedTokenizerFast

import numpy
import torch


class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer wrapper for a decoder-style (causal) BERT variant.

    Post-processes the standard fast-tokenizer batch encoding by:
      * removing ``token_type_ids`` (segment ids are unused by a decoder), and
      * dropping the final position from ``input_ids`` and ``attention_mask``
        along the sequence axis — presumably to strip a trailing special
        token appended by the underlying tokenizer (TODO: confirm against
        the tokenizer's post-processor config).
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Delegate to the parent encoder, then trim the outputs.

        Returns the parent's ``BatchEncoding`` with ``token_type_ids``
        removed and the last sequence position sliced off ``input_ids``
        and ``attention_mask``, whatever container type they come back as
        (torch tensor, numpy array, or list of python lists).
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # pop() instead of del: the key is absent when callers request
        # return_token_type_ids=False (del would raise KeyError).
        outputs.pop("token_type_ids", None)

        for key in ("input_ids", "attention_mask"):
            value = outputs[key]
            if isinstance(value, (torch.Tensor, numpy.ndarray)):
                # Both tensor types support ellipsis slicing; trim the last
                # element of the final (sequence) axis.
                outputs[key] = value[..., :-1]
            elif isinstance(value, list):
                # Un-tensorized output: list of per-example token lists.
                outputs[key] = [sequence[:-1] for sequence in value]
        return outputs


# Register the class so AutoTokenizer can resolve it as a fast tokenizer.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)