michael-guenther
commited on
Commit
•
4b66519
1
Parent(s):
6170b43
add assertions and docs
Browse files- tokenizer.py +11 -2
tokenizer.py
CHANGED
@@ -5,9 +5,15 @@ import warnings
|
|
5 |
|
6 |
|
7 |
class JinaTokenizer(RobertaTokenizer):
|
8 |
-
def __init__(self, *args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
super().__init__(*args, **kwargs)
|
10 |
-
self.task_type_vocab_size = task_type_vocab_size
|
11 |
|
12 |
def __call__(self, *args, task_type=None, **kwargs):
|
13 |
batch_encoding = super().__call__(*args, **kwargs)
|
@@ -50,6 +56,9 @@ class JinaTokenizer(RobertaTokenizer):
|
|
50 |
|
51 |
def apply_task_type(m, x):
|
52 |
x = torch.tensor(x)
|
|
|
|
|
|
|
53 |
return m * x if len(x.shape) == 0 else m * x[:, None]
|
54 |
|
55 |
if isinstance(batch_encoding['input_ids'], torch.Tensor):
|
|
|
5 |
|
6 |
|
7 |
class JinaTokenizer(RobertaTokenizer):
|
8 |
+
def __init__(self, *args, **kwargs):
|
9 |
+
"""
|
10 |
+
JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
|
11 |
+
the batch encoding.
|
12 |
+
The task_type_ids are used to pass instruction information to the model.
|
13 |
+
A task_type should either be an integer or a sequence of integers with the same
|
14 |
+
length as the batch size.
|
15 |
+
"""
|
16 |
super().__init__(*args, **kwargs)
|
|
|
17 |
|
18 |
def __call__(self, *args, task_type=None, **kwargs):
|
19 |
batch_encoding = super().__call__(*args, **kwargs)
|
|
|
56 |
|
57 |
def apply_task_type(m, x):
|
58 |
x = torch.tensor(x)
|
59 |
+
assert (
|
60 |
+
len(x.shape) == 0 or x.shape[0] == m.shape[0]
|
61 |
+
), 'The shape of task_type does not match the size of the batch.'
|
62 |
return m * x if len(x.shape) == 0 else m * x[:, None]
|
63 |
|
64 |
if isinstance(batch_encoding['input_ids'], torch.Tensor):
|