jupyterjazz commited on
Commit
9b5c148
1 Parent(s): 851184a

feat: support fast tokenizer

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1) hide show
  1. tokenizer.py +84 -77
tokenizer.py CHANGED
@@ -1,89 +1,96 @@
1
  import torch
2
  import numpy as np
3
- from transformers import RobertaTokenizer, BatchEncoding
4
  import warnings
5
 
6
 
7
- class JinaTokenizer(RobertaTokenizer):
8
- def __init__(self, *args, **kwargs):
9
- """
10
- JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
11
- the batch encoding.
12
- The task_type_ids are used to pass instruction information to the model.
13
- A task_type should either be an integer or a sequence of integers with the same
14
- length as the batch size.
15
- """
16
- super().__init__(*args, **kwargs)
 
17
 
18
- def __call__(self, *args, task_type=None, **kwargs):
19
- batch_encoding = super().__call__(*args, **kwargs)
20
- if task_type is not None:
21
- batch_encoding = BatchEncoding(
22
- {
23
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
24
- **batch_encoding,
25
- },
26
- tensor_type=kwargs.get('return_tensors'),
27
- )
28
- return batch_encoding
29
 
30
- def _batch_encode_plus(self, *args, task_type=None, **kwargs):
31
- batch_encoding = super()._batch_encode_plus(*args, **kwargs)
32
- if task_type is not None:
33
- batch_encoding = BatchEncoding(
34
- {
35
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
36
- **batch_encoding,
37
- },
38
- tensor_type=kwargs.get('return_tensors'),
39
- )
40
- return batch_encoding
41
 
42
- def _encode_plus(self, *args, task_type=None, **kwargs):
43
- batch_encoding = super()._encode_plus(*args, **kwargs)
44
- if task_type is not None:
45
- batch_encoding = BatchEncoding(
46
- {
47
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
48
- **batch_encoding,
49
- },
50
- tensor_type=kwargs.get('return_tensors'),
51
- )
52
- return batch_encoding
53
 
54
- @staticmethod
55
- def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
56
 
57
- def apply_task_type(m, x):
58
- x = torch.tensor(x)
59
- assert (
60
- len(x.shape) == 0 or x.shape[0] == m.shape[0]
61
- ), 'The shape of task_type does not match the size of the batch.'
62
- return m * x if len(x.shape) == 0 else m * x[:, None]
63
 
64
- if isinstance(batch_encoding['input_ids'], torch.Tensor):
65
- shape = batch_encoding['input_ids'].shape
66
- return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
67
- else:
68
- try:
69
- shape = torch.tensor(batch_encoding['input_ids']).shape
70
- except:
71
- raise ValueError(
72
- "Unable to create tensor, you should probably "
73
- "activate truncation and/or padding with "
74
- "'padding=True' 'truncation=True' to have batched "
75
- "tensors with the same length."
76
- )
77
- if isinstance(batch_encoding['input_ids'], list):
78
- return (
79
- apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
80
- ).tolist()
81
- elif isinstance(batch_encoding['input_ids'], np.array):
82
- return (
83
- apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
84
- ).numpy()
85
- else:
86
- warnings.warn(
87
- 'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
88
- )
89
  return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
+ from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
4
  import warnings
5
 
6
 
7
+ def get_tokenizer(parent_class):
8
+ class TokenizerClass(parent_class):
9
+ def __init__(self, *args, **kwargs):
10
+ """
11
+ JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
12
+ the batch encoding.
13
+ The task_type_ids are used to pass instruction information to the model.
14
+ A task_type should either be an integer or a sequence of integers with the same
15
+ length as the batch size.
16
+ """
17
+ super().__init__(*args, **kwargs)
18
 
19
+ def __call__(self, *args, task_type=None, **kwargs):
20
+ batch_encoding = super().__call__(*args, **kwargs)
21
+ if task_type is not None:
22
+ batch_encoding = BatchEncoding(
23
+ {
24
+ 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
25
+ **batch_encoding,
26
+ },
27
+ tensor_type=kwargs.get('return_tensors'),
28
+ )
29
+ return batch_encoding
30
 
31
+ def _batch_encode_plus(self, *args, task_type=None, **kwargs):
32
+ batch_encoding = super()._batch_encode_plus(*args, **kwargs)
33
+ if task_type is not None:
34
+ batch_encoding = BatchEncoding(
35
+ {
36
+ 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
37
+ **batch_encoding,
38
+ },
39
+ tensor_type=kwargs.get('return_tensors'),
40
+ )
41
+ return batch_encoding
42
 
43
+ def _encode_plus(self, *args, task_type=None, **kwargs):
44
+ batch_encoding = super()._encode_plus(*args, **kwargs)
45
+ if task_type is not None:
46
+ batch_encoding = BatchEncoding(
47
+ {
48
+ 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
49
+ **batch_encoding,
50
+ },
51
+ tensor_type=kwargs.get('return_tensors'),
52
+ )
53
+ return batch_encoding
54
 
55
+ @staticmethod
56
+ def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
57
 
58
+ def apply_task_type(m, x):
59
+ x = torch.tensor(x)
60
+ assert (
61
+ len(x.shape) == 0 or x.shape[0] == m.shape[0]
62
+ ), 'The shape of task_type does not match the size of the batch.'
63
+ return m * x if len(x.shape) == 0 else m * x[:, None]
64
 
65
+ if isinstance(batch_encoding['input_ids'], torch.Tensor):
66
+ shape = batch_encoding['input_ids'].shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
68
+ else:
69
+ try:
70
+ shape = torch.tensor(batch_encoding['input_ids']).shape
71
+ except:
72
+ raise ValueError(
73
+ "Unable to create tensor, you should probably "
74
+ "activate truncation and/or padding with "
75
+ "'padding=True' 'truncation=True' to have batched "
76
+ "tensors with the same length."
77
+ )
78
+ if isinstance(batch_encoding['input_ids'], list):
79
+ return (
80
+ apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
81
+ ).tolist()
82
+ elif isinstance(batch_encoding['input_ids'], np.array):
83
+ return (
84
+ apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
85
+ ).numpy()
86
+ else:
87
+ warnings.warn(
88
+ 'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
89
+ )
90
+ return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
91
+
92
+ return TokenizerClass
93
+
94
+
95
+ JinaTokenizer = get_tokenizer(RobertaTokenizer)
96
+ JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)