oweller2
commited on
Commit
•
b54c050
1
Parent(s):
6e82f17
done
Browse files- tokenizer.py +11 -11
tokenizer.py
CHANGED
@@ -5,17 +5,17 @@ import torch
|
|
5 |
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
6 |
|
7 |
def _batch_encode_plus(self, *args, **kwargs):
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
|
20 |
# Register the class
|
21 |
from transformers import AutoTokenizer
|
|
|
5 |
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
6 |
|
7 |
def _batch_encode_plus(self, *args, **kwargs):
|
8 |
+
outputs = super()._batch_encode_plus(*args, **kwargs)
|
9 |
+
del outputs["token_type_ids"]
|
10 |
+
for key in ['input_ids', 'attention_mask']:
|
11 |
+
if isinstance(outputs[key], (list, numpy.ndarray, torch.Tensor)):
|
12 |
+
if isinstance(outputs[key], list):
|
13 |
+
outputs[key] = [sequence[:-1] for sequence in outputs[key]]
|
14 |
+
elif isinstance(outputs[key], numpy.ndarray):
|
15 |
+
outputs[key] = numpy.array([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype)
|
16 |
+
elif isinstance(outputs[key], torch.Tensor):
|
17 |
+
outputs[key] = torch.tensor([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype, device=outputs[key].device)
|
18 |
+
return outputs
|
19 |
|
20 |
# Register the class
|
21 |
from transformers import AutoTokenizer
|