oweller2
commited on
Commit
•
b7a2cf0
1
Parent(s):
b54c050
done:
Browse files- tokenizer.py +6 -7
tokenizer.py
CHANGED
@@ -8,13 +8,12 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
|
8 |
outputs = super()._batch_encode_plus(*args, **kwargs)
|
9 |
del outputs["token_type_ids"]
|
10 |
for key in ['input_ids', 'attention_mask']:
|
11 |
-
if isinstance(outputs[key],
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
outputs[key] = torch.tensor([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype, device=outputs[key].device)
|
18 |
return outputs
|
19 |
|
20 |
# Register the class
|
|
|
8 |
outputs = super()._batch_encode_plus(*args, **kwargs)
|
9 |
del outputs["token_type_ids"]
|
10 |
for key in ['input_ids', 'attention_mask']:
|
11 |
+
if isinstance(outputs[key], torch.Tensor):
|
12 |
+
outputs[key] = outputs[key][..., :-1]
|
13 |
+
elif isinstance(outputs[key], numpy.ndarray):
|
14 |
+
outputs[key] = outputs[key][..., :-1]
|
15 |
+
elif isinstance(outputs[key], list):
|
16 |
+
outputs[key] = [sequence[:-1] for sequence in outputs[key]]
|
|
|
17 |
return outputs
|
18 |
|
19 |
# Register the class
|