Zymrael commited on
Commit
a35de04
1 Parent(s): 2c3da52

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +26 -12
tokenizer.py CHANGED
@@ -75,14 +75,25 @@ class HFAutoTokenizer:
75
 
76
  class ByteTokenizer(PreTrainedTokenizer):
77
  """UTF-8 Encoder."""
78
-
79
  def __init__(self):
80
- self.vocab_size = 512
81
- self.eod_id = 0
82
- self.eos_id = 0
83
- self.eos_token = 0
84
- self.eos_token_id = 0
85
- self.pad_id = 1
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def clamp(self, n):
88
  return max(32, min(n, self.vocab_size))
@@ -90,12 +101,15 @@ class ByteTokenizer(PreTrainedTokenizer):
90
  def decode_token(self, token: int):
91
  return str(chr(self.clamp(token)))
92
 
93
- def __call__(self, text: str, *args, **kwargs):
94
  ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
95
- return {"input_ids": ids}
96
-
 
 
 
97
  def tokenize(self, text: str):
98
- return list(np.fromstring(text, dtype=np.uint8))
99
 
100
  def tokenize_batch(self, text_batch: Union[List[str], str]):
101
  if isinstance(text_batch, list):
@@ -109,7 +123,7 @@ class ByteTokenizer(PreTrainedTokenizer):
109
  def decode_batch(self, token_ids: Union[List[str], str]):
110
  if isinstance(token_ids, list):
111
  return [self.decode(s) for s in token_ids]
112
- # elif if tensor, convert to list first
113
  elif isinstance(token_ids, torch.Tensor):
114
  return [self.decode(s) for s in token_ids.tolist()]
115
  else:
 
75
 
76
  class ByteTokenizer(PreTrainedTokenizer):
77
  """UTF-8 Encoder."""
 
78
  def __init__(self):
79
+ super().__init__(
80
+ bos_token=self.decode_token(2),
81
+ eos_token=self.decode_token(0),
82
+ unk_token=self.decode_token(0),
83
+ pad_token=self.decode_token(1),
84
+ mask_token=self.decode_token(3),
85
+ )
86
+
87
+ @property
88
+ def vocab_size(self) -> int:
89
+ return 512
90
+
91
+ @classmethod
92
+ def from_pretrained(cls, *args, **kwargs):
93
+ return cls()
94
+
95
+ def get_vocab(self):
96
+ return {str(i): i for i in range(512)}
97
 
98
  def clamp(self, n):
99
  return max(32, min(n, self.vocab_size))
 
101
  def decode_token(self, token: int):
102
  return str(chr(self.clamp(token)))
103
 
104
+ def __call__(self, text: str, return_tensors: bool = False, *args, **kwargs):
105
  ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
106
+ return {"input_ids": ids} if return_tensors == False else ids
107
+
108
+ def _tokenize(self, text: str):
109
+ return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
110
+
111
  def tokenize(self, text: str):
112
+ return self._tokenize(text).tolist()
113
 
114
  def tokenize_batch(self, text_batch: Union[List[str], str]):
115
  if isinstance(text_batch, list):
 
123
  def decode_batch(self, token_ids: Union[List[str], str]):
124
  if isinstance(token_ids, list):
125
  return [self.decode(s) for s in token_ids]
126
+
127
  elif isinstance(token_ids, torch.Tensor):
128
  return [self.decode(s) for s in token_ids.tolist()]
129
  else: