import warnings
from enum import IntEnum
from typing import Optional, Sequence, Union

import numpy as np
import torch
from transformers import BatchEncoding, RobertaTokenizer, RobertaTokenizerFast

# Accepted forms of ``task_type``: a single value applied to the whole batch,
# or one value per batch row.
_TaskTypeArg = Union[int, Sequence[int], np.ndarray, torch.Tensor]


def get_tokenizer(parent_class):
    """Return a subclass of ``parent_class`` that can emit ``task_type_ids``.

    Args:
        parent_class: tokenizer class to extend (in this module,
            ``RobertaTokenizer`` or ``RobertaTokenizerFast``).

    Returns:
        The dynamically created ``TokenizerClass`` subclass.
    """

    class TokenizerClass(parent_class):
        """Tokenizer that optionally adds ``task_type_ids`` to its output.

        This class dynamically extends a given tokenizer class from the HF
        Transformers library (RobertaTokenizer or RobertaTokenizerFast). The
        task_type_ids are used to pass instruction information to the model.
        A task_type should either be an integer or a sequence of integers with
        the same length as the batch size. When ``task_type`` is omitted (or
        ``None``) the output is exactly what ``parent_class`` returns.
        """

        class TaskTypes(IntEnum):
            NULL = 0
            QUERY = 1
            DOCUMENT = 2
            STS = 3
            CLUSTERING = 4
            CLASSIFICATION = 5

        def __call__(self, *args, task_type: Optional[_TaskTypeArg] = None, **kwargs):
            """Tokenize like ``parent_class.__call__``, optionally adding task ids."""
            batch_encoding = super().__call__(*args, **kwargs)
            return self._maybe_add_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        def _batch_encode_plus(
            self, *args, task_type: Optional[_TaskTypeArg] = None, **kwargs
        ):
            """Batch-encode like the parent, optionally adding task ids."""
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            return self._maybe_add_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        def _encode_plus(self, *args, task_type: Optional[_TaskTypeArg] = None, **kwargs):
            """Encode a single example like the parent, optionally adding task ids."""
            batch_encoding = super()._encode_plus(*args, **kwargs)
            return self._maybe_add_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        @classmethod
        def _maybe_add_task_type_ids(
            cls,
            batch_encoding: BatchEncoding,
            task_type: Optional[_TaskTypeArg],
            tensor_type: Optional[str],
        ) -> BatchEncoding:
            """Add ``task_type_ids`` unless ``task_type`` is ``None``."""
            if task_type is None:
                return batch_encoding
            return cls._add_task_type_ids(batch_encoding, task_type, tensor_type)

        @classmethod
        def _add_task_type_ids(
            cls,
            batch_encoding: BatchEncoding,
            task_type: _TaskTypeArg,
            tensor_type: Optional[str],
        ) -> BatchEncoding:
            """Return a new ``BatchEncoding`` with a ``task_type_ids`` entry first.

            Args:
                batch_encoding: output of the parent tokenizer.
                task_type: scalar or per-row task type(s).
                tensor_type: the caller's ``return_tensors`` value, used to
                    convert the combined encoding.
            """
            return BatchEncoding(
                {
                    'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                # Carry over the fast tokenizer's Encoding objects (None for
                # slow tokenizers) so encoding-based helpers such as
                # ``word_ids()`` keep working on the returned object.
                encoding=batch_encoding.encodings,
                tensor_type=tensor_type,
            )

        @staticmethod
        def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: _TaskTypeArg):
            """Build task-type ids with the same shape and container as input_ids.

            Returns a ``torch.Tensor`` when ``input_ids`` is a tensor, a list
            for lists, an ``np.ndarray`` for numpy arrays, and (with a warning)
            a ``torch.Tensor`` for anything else.

            Raises:
                ValueError: if ``input_ids`` cannot be turned into a
                    rectangular tensor, or if a per-row ``task_type`` does not
                    match the batch size.
            """

            def apply_task_type(m, x):
                # as_tensor avoids the copy-construct warning that
                # torch.tensor emits when x is already a tensor.
                x = torch.as_tensor(x)
                if not (x.dim() == 0 or x.shape[0] == m.shape[0]):
                    raise ValueError(
                        'The shape of task_type does not match the size of the batch.'
                    )
                # Scalar: broadcast over everything; per-row: broadcast each
                # row's value across the sequence dimension.
                return m * x if x.dim() == 0 else m * x[:, None]

            input_ids = batch_encoding['input_ids']
            if isinstance(input_ids, torch.Tensor):
                return apply_task_type(
                    torch.ones(input_ids.shape, dtype=torch.long), task_type
                )

            try:
                shape = torch.tensor(input_ids).shape
            except (ValueError, TypeError, RuntimeError) as err:
                # Typically ragged (unpadded) batches.
                raise ValueError(
                    "Unable to create tensor, you should probably "
                    "activate truncation and/or padding with "
                    "'padding=True' 'truncation=True' to have batched "
                    "tensors with the same length."
                ) from err

            ones = torch.ones(shape, dtype=torch.long)
            if isinstance(input_ids, list):
                return apply_task_type(ones, task_type).tolist()
            if isinstance(input_ids, np.ndarray):
                return apply_task_type(ones, task_type).numpy()
            warnings.warn(
                'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
            )
            return apply_task_type(ones, task_type)

    return TokenizerClass


JinaTokenizer = get_tokenizer(RobertaTokenizer)
JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)