Update tokenizer.py
Browse files- tokenizer.py +26 -12
tokenizer.py
CHANGED
@@ -75,14 +75,25 @@ class HFAutoTokenizer:
|
|
75 |
|
76 |
class ByteTokenizer(PreTrainedTokenizer):
|
77 |
"""UTF-8 Encoder."""
|
78 |
-
|
79 |
def __init__(self):
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
def clamp(self, n):
|
88 |
return max(32, min(n, self.vocab_size))
|
@@ -90,12 +101,15 @@ class ByteTokenizer(PreTrainedTokenizer):
|
|
90 |
def decode_token(self, token: int):
|
91 |
return str(chr(self.clamp(token)))
|
92 |
|
93 |
-
def __call__(self, text: str, *args, **kwargs):
|
94 |
ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
|
95 |
-
return {"input_ids": ids}
|
96 |
-
|
|
|
|
|
|
|
97 |
def tokenize(self, text: str):
|
98 |
-
return
|
99 |
|
100 |
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
101 |
if isinstance(text_batch, list):
|
@@ -109,7 +123,7 @@ class ByteTokenizer(PreTrainedTokenizer):
|
|
109 |
def decode_batch(self, token_ids: Union[List[str], str]):
|
110 |
if isinstance(token_ids, list):
|
111 |
return [self.decode(s) for s in token_ids]
|
112 |
-
|
113 |
elif isinstance(token_ids, torch.Tensor):
|
114 |
return [self.decode(s) for s in token_ids.tolist()]
|
115 |
else:
|
|
|
75 |
|
76 |
class ByteTokenizer(PreTrainedTokenizer):
|
77 |
"""UTF-8 Encoder."""
|
|
|
78 |
def __init__(self):
|
79 |
+
super().__init__(
|
80 |
+
bos_token=self.decode_token(2),
|
81 |
+
eos_token=self.decode_token(0),
|
82 |
+
unk_token=self.decode_token(0),
|
83 |
+
pad_token=self.decode_token(1),
|
84 |
+
mask_token=self.decode_token(3),
|
85 |
+
)
|
86 |
+
|
87 |
+
@property
|
88 |
+
def vocab_size(self) -> int:
|
89 |
+
return 512
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def from_pretrained(cls, *args, **kwargs):
|
93 |
+
return cls()
|
94 |
+
|
95 |
+
def get_vocab(self):
|
96 |
+
return {str(i): i for i in range(512)}
|
97 |
|
98 |
def clamp(self, n):
|
99 |
return max(32, min(n, self.vocab_size))
|
|
|
101 |
def decode_token(self, token: int):
|
102 |
return str(chr(self.clamp(token)))
|
103 |
|
104 |
+
def __call__(self, text: str, return_tensors: bool = False, *args, **kwargs):
|
105 |
ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
|
106 |
+
return {"input_ids": ids} if return_tensors == False else ids
|
107 |
+
|
108 |
+
def _tokenize(self, text: str):
|
109 |
+
return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
|
110 |
+
|
111 |
def tokenize(self, text: str):
|
112 |
+
return self._tokenize(text).tolist()
|
113 |
|
114 |
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
115 |
if isinstance(text_batch, list):
|
|
|
123 |
def decode_batch(self, token_ids: Union[List[str], str]):
|
124 |
if isinstance(token_ids, list):
|
125 |
return [self.decode(s) for s in token_ids]
|
126 |
+
|
127 |
elif isinstance(token_ids, torch.Tensor):
|
128 |
return [self.decode(s) for s in token_ids.tolist()]
|
129 |
else:
|