from transformers import AutoTokenizer, PreTrainedTokenizerFast
import numpy
import torch


class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):

    def _batch_encode_plus(self, *args, **kwargs):
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # Decoder-style training does not use token_type_ids; pop() (rather
        # than del) avoids a KeyError when the fast tokenizer did not return
        # them (e.g. return_token_type_ids=False).
        outputs.pop("token_type_ids", None)

        # Get the input_ids to check for EOS tokens
        input_ids = outputs['input_ids']

        # Function to check if a sequence ends with the EOS token
        def ends_with_eos(sequence):
            if len(sequence) == 0:
                return False
            return sequence[-1] == self.eos_token_id

        # Check for EOS tokens using input_ids only
        if isinstance(input_ids, torch.Tensor):
            last_token_is_eos = torch.tensor(
                [ends_with_eos(seq) for seq in input_ids], dtype=torch.bool
            )

            if last_token_is_eos.all():
                # If all sequences end with EOS, just truncate the last column
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                # Otherwise process each sequence individually so the batch
                # keeps a uniform length
                batch_size = input_ids.shape[0]
                for i in range(batch_size):
                    if last_token_is_eos[i]:
                        for key in ['input_ids', 'attention_mask']:
                            # Drop the trailing EOS and left-pad with a zero
                            # (this assumes the pad token id is 0)
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = torch.cat(
                                [torch.zeros_like(truncated[:1]), truncated]
                            )

        elif isinstance(input_ids, numpy.ndarray):
            last_token_is_eos = numpy.array(
                [ends_with_eos(seq) for seq in input_ids], dtype=bool
            )

            if last_token_is_eos.all():
                # If all sequences end with EOS, just truncate the last column
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                batch_size = input_ids.shape[0]
                for i in range(batch_size):
                    if last_token_is_eos[i]:
                        for key in ['input_ids', 'attention_mask']:
                            # Drop the trailing EOS and left-pad with a zero
                            # (this assumes the pad token id is 0)
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = numpy.concatenate(
                                [numpy.zeros_like(truncated[:1]), truncated]
                            )

        elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]

            if all(last_token_is_eos):
                # If all sequences end with EOS, just truncate them all
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = [sequence[:-1] for sequence in outputs[key]]
            elif any(last_token_is_eos):
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = [
                        [0] + sequence[:-1] if is_eos else sequence
                        for sequence, is_eos in zip(outputs[key], last_token_is_eos)
                    ]

        return outputs


# Register the class so AutoTokenizer resolves it when this file ships in a
# model repo loaded with trust_remote_code=True
AutoTokenizer.register(ModernDecoderBERTTokenizer,
                       fast_tokenizer_class=ModernDecoderBERTTokenizer)
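

# Usage sketch (not part of the original file): a minimal check of the
# EOS-stripping behavior above, assuming this tokenizer.py ships alongside a
# tokenizer.json in a model repo loaded with trust_remote_code=True. The repo
# id "your-org/your-decoder-checkpoint" is a hypothetical placeholder.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(
        "your-org/your-decoder-checkpoint",  # hypothetical placeholder
        trust_remote_code=True,
    )
    batch = tokenizer(["hello world", "a longer example sentence"],
                      padding=True, return_tensors="pt")

    # token_type_ids were dropped, and no sequence in the batch should still
    # end with the EOS token appended by the underlying fast tokenizer.
    assert "token_type_ids" not in batch
    assert (batch["input_ids"][:, -1] != tokenizer.eos_token_id).all()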