michael-guenther committed
Commit 4b66519
1 Parent(s): 6170b43

add assertions and docs

Files changed (1): tokenizer.py (+11 -2)
tokenizer.py CHANGED

@@ -5,9 +5,15 @@ import warnings
 
 
 class JinaTokenizer(RobertaTokenizer):
-    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
+    def __init__(self, *args, **kwargs):
+        """
+        JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
+        the batch encoding.
+        The task_type_ids are used to pass instruction information to the model.
+        A task_type should either be an integer or a sequence of integers with the same
+        length as the batch size.
+        """
         super().__init__(*args, **kwargs)
-        self.task_type_vocab_size = task_type_vocab_size
 
     def __call__(self, *args, task_type=None, **kwargs):
        batch_encoding = super().__call__(*args, **kwargs)
@@ -50,6 +56,9 @@ class JinaTokenizer(RobertaTokenizer):
 
         def apply_task_type(m, x):
             x = torch.tensor(x)
+            assert (
+                len(x.shape) == 0 or x.shape[0] == m.shape[0]
+            ), 'The shape of task_type does not match the size of the batch.'
             return m * x if len(x.shape) == 0 else m * x[:, None]
 
         if isinstance(batch_encoding['input_ids'], torch.Tensor):
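The broadcasting behavior the new assertion guards can be reproduced in isolation. The sketch below copies the apply_task_type helper from the diff; the assumption that m is a mask-like integer tensor of shape (batch_size, seq_len) is inferred from how the result is used as task_type_ids, not stated in this commit.

import torch

def apply_task_type(m, x):
    # Copied from the diff above; m is assumed to be a mask-like tensor of
    # shape (batch_size, seq_len), x a scalar task_type or one id per example.
    x = torch.tensor(x)
    assert (
        len(x.shape) == 0 or x.shape[0] == m.shape[0]
    ), 'The shape of task_type does not match the size of the batch.'
    return m * x if len(x.shape) == 0 else m * x[:, None]

m = torch.ones(2, 4, dtype=torch.long)
print(apply_task_type(m, 3))       # scalar id broadcast over every token
print(apply_task_type(m, [1, 2]))  # one id per example, expanded along seq_len
# apply_task_type(m, [1, 2, 3])    # wrong batch size: trips the new assertion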
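A minimal end-to-end usage sketch, assuming the tokenizer ships as custom code in a Hub repository. The repository id below is a placeholder, and padding/return_tensors are standard transformers arguments rather than part of this commit.

from transformers import AutoTokenizer

# Placeholder repository id; any checkpoint bundling this tokenizer.py as
# custom code would be loaded the same way.
tokenizer = AutoTokenizer.from_pretrained('jinaai/<model-repo>', trust_remote_code=True)

texts = ['How is the weather today?', 'What is the weather like today?']

# A single integer tags every sequence in the batch with the same task id.
batch = tokenizer(texts, task_type=1, padding=True, return_tensors='pt')
print(batch['task_type_ids'])

# A sequence of integers must provide exactly one id per example; a length
# mismatch now fails with the assertion message added in this commit.
batch = tokenizer(texts, task_type=[0, 1], padding=True, return_tensors='pt')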