Make _decode compatible with PreTrainedTokenizerBase
CodeGen25Tokenizer correctly implements the `_decode` interface from `PreTrainedTokenizer` with the following signature
```py
def _decode(self, token_ids: List[int], ...) -> str:
...
```
`PreTrainedTokenizer`, however, incorrectly shadows the `_decode` function of its base class `PreTrainedTokenizerBase`, which is defined like this:
```py
def _decode(self, token_ids: Union[int, List[int]], ...) -> str:
...
```
As a result, CodeGen25Tokenizer cannot be used as a drop-in tokenizer in some codebases (like [TGI](https://github.com/huggingface/text-generation-inference)). This fix doesn't break any previous behaviour, but simply allows `decode` to also accept plain `int` values instead of only `List[int]`.
tokenization_codegen25.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
|
5 |
"""Tokenization classes for CodeGen2.5."""
|
6 |
|
7 |
-
from typing import List, Optional
|
8 |
|
9 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
10 |
from transformers.utils import logging
|
@@ -168,7 +168,9 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
|
|
168 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
169 |
return self.encoder.decode_single_token_bytes(index).decode("utf-8")
|
170 |
|
171 |
-
def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
|
|
|
|
|
172 |
if skip_special_tokens:
|
173 |
token_ids = [t for t in token_ids if t not in self.all_special_ids]
|
174 |
return self.encoder.decode(token_ids)
|
|
|
4 |
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
|
5 |
"""Tokenization classes for CodeGen2.5."""
|
6 |
|
7 |
+
from typing import List, Optional, Union
|
8 |
|
9 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
10 |
from transformers.utils import logging
|
|
|
168 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
169 |
return self.encoder.decode_single_token_bytes(index).decode("utf-8")
|
170 |
|
171 |
+
def _decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, **kwargs):
|
172 |
+
if isinstance(token_ids, int):
|
173 |
+
token_ids = [token_ids]
|
174 |
if skip_special_tokens:
|
175 |
token_ids = [t for t in token_ids if t not in self.all_special_ids]
|
176 |
return self.encoder.decode(token_ids)
|