Vinno97 commited on
Commit
1d07fb3
1 Parent(s): 08bd09d

Make _decode compatible with PreTrainedTokenizerBase

Browse files

CodeGen25Tokenizer correctly implements the `_decode` interface from `PreTrainedTokenizer` with the following signature

```py
def _decode(self, token_ids: List[int], ...,) -> str:
...
```

`PreTrainedTokenizer`, however, incorrectly shadows the `_decode` function of its base class `PreTrainedTokenizerBase`, which is defined like this:

```py
def _decode(self, token_ids: Union[int, List[int]], ...,) -> str:
...
```

As a result, CodeGen25Tokenizer cannot be used as a drop-in tokenizer in some codebases (like [TGI](https://github.com/huggingface/text-generation-inference)). This fix doesn't break any previous behaviour, but simply allows `decode` to also accept plain `int` values instead of only `list[int]`.

Files changed (1) hide show
  1. tokenization_codegen25.py +4 -2
tokenization_codegen25.py CHANGED
@@ -4,7 +4,7 @@
4
  # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
5
  """Tokenization classes for CodeGen2.5."""
6
 
7
- from typing import List, Optional
8
 
9
  from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
10
  from transformers.utils import logging
@@ -168,7 +168,9 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
168
  """Converts an index (integer) in a token (str) using the vocab."""
169
  return self.encoder.decode_single_token_bytes(index).decode("utf-8")
170
 
171
- def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
 
 
172
  if skip_special_tokens:
173
  token_ids = [t for t in token_ids if t not in self.all_special_ids]
174
  return self.encoder.decode(token_ids)
 
4
  # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
5
  """Tokenization classes for CodeGen2.5."""
6
 
7
+ from typing import List, Optional, Union
8
 
9
  from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
10
  from transformers.utils import logging
 
168
  """Converts an index (integer) in a token (str) using the vocab."""
169
  return self.encoder.decode_single_token_bytes(index).decode("utf-8")
170
 
171
+ def _decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, **kwargs):
172
+ if isinstance(token_ids, int):
173
+ token_ids = [token_ids]
174
  if skip_special_tokens:
175
  token_ids = [t for t in token_ids if t not in self.all_special_ids]
176
  return self.encoder.decode(token_ids)