Files changed (1)
  1. tokenizer.py +42 -15
tokenizer.py CHANGED
@@ -5,19 +5,26 @@ import warnings
 
 
 class JinaTokenizer(RobertaTokenizer):
-    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
+    def __init__(self, *args, **kwargs):
+        """
+        JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
+        the batch encoding.
+        The task_type_ids are used to pass instruction information to the model.
+        A task_type should either be an integer or a sequence of integers with the same
+        length as the batch size.
+        """
         super().__init__(*args, **kwargs)
-        self.task_type_vocab_size = task_type_vocab_size
 
     def __call__(self, *args, task_type=None, **kwargs):
         batch_encoding = super().__call__(*args, **kwargs)
-        batch_encoding = BatchEncoding(
-            {
-                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                **batch_encoding,
-            },
-            tensor_type=kwargs.get('return_tensors'),
-        )
+        if task_type is not None:
+            batch_encoding = BatchEncoding(
+                {
+                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
+                    **batch_encoding,
+                },
+                tensor_type=kwargs.get('return_tensors'),
+            )
         return batch_encoding
 
     def _batch_encode_plus(self, *args, task_type=None, **kwargs):
@@ -45,18 +52,38 @@ class JinaTokenizer(RobertaTokenizer):
         return batch_encoding
 
     @staticmethod
-    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
+    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
+
+        def apply_task_type(m, x):
+            x = torch.tensor(x)
+            assert (
+                len(x.shape) == 0 or x.shape[0] == m.shape[0]
+            ), 'The shape of task_type does not match the size of the batch.'
+            return m * x if len(x.shape) == 0 else m * x[:, None]
+
         if isinstance(batch_encoding['input_ids'], torch.Tensor):
             shape = batch_encoding['input_ids'].shape
-            return torch.ones(shape, dtype=torch.long) * task_type
+            return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
         else:
-            shape = torch.tensor(batch_encoding['input_ids']).shape
+            try:
+                shape = torch.tensor(batch_encoding['input_ids']).shape
+            except:
+                raise ValueError(
+                    "Unable to create tensor, you should probably "
+                    "activate truncation and/or padding with "
+                    "'padding=True' 'truncation=True' to have batched "
+                    "tensors with the same length."
+                )
             if isinstance(batch_encoding['input_ids'], list):
-                return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
+                return (
+                    apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                ).tolist()
             elif isinstance(batch_encoding['input_ids'], np.ndarray):
-                return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
+                return (
+                    apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                ).numpy()
             else:
                 warnings.warn(
                     'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
                 )
-                return torch.ones(shape, dtype=torch.long) * task_type
+                return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
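
Usage sketch (not part of the diff): the snippet below shows how the new task_type keyword is meant to be used. It assumes the tokenizer.py above is importable locally and borrows the public roberta-base vocabulary only to obtain real tokenizer files; the task-type integers (1, 2) are placeholders whose meaning is defined by the model, not by this file.

from tokenizer import JinaTokenizer

# Loading roberta-base vocab files into the subclass is an assumption made
# purely for this example.
tokenizer = JinaTokenizer.from_pretrained('roberta-base')

texts = ['how is the weather today', 'what is the weather like today']

# Scalar task_type: every token of every sequence gets the same id.
batch = tokenizer(texts, task_type=1, padding=True, return_tensors='pt')
print(batch['task_type_ids'].shape)  # (2, seq_len), filled with 1

# Per-sample task_type: one id per sequence, broadcast over the sequence
# length by apply_task_type via x[:, None].
batch = tokenizer(texts, task_type=[1, 2], padding=True, return_tensors='pt')
print(batch['task_type_ids'][:, 0])  # tensor([1, 2])

# With the default task_type=None the encoding is returned unchanged and
# contains no 'task_type_ids' key.

Passing a sequence whose length does not match the batch size trips the assert in apply_task_type, so mismatched per-sample task types fail loudly instead of broadcasting silently.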