# poet/char_tokenizer.py
import json

import torch
class CharTokenizer:
    """Character-level tokenizer built from a corpus or restored from a
    saved vocabulary mapping."""

    def __init__(self, corpus=None, vocab=None):
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            raise ValueError("Either a corpus or a vocab has to be supplied")
        # Inverse mapping: list where position i holds the token with id i.
        self.id2vocab = [char for char, index in sorted(self.vocab.items(), key=lambda item: item[1])]
    def _tokenize(self, text):
        # Character-level tokenization: every character is its own token.
        return list(text)
    def __call__(self, prompt, text=None, add_eos_token=False):
        # Unknown characters fall back to id 0 (<pad>).
        token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]
        if text is not None:
            # Join prompt and text with a <bos> separator.
            text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]
            token_ids = token_ids + [self.vocab["<bos>"]] + text_token_ids
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        # Shape (1, seq_len), mirroring the Hugging Face tokenizer output format.
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}
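
    # Illustrative call (the prompt/text below are hypothetical and assume
    # the vocab covers them):
    #   tokenizer("5", text="床前明月光", add_eos_token=True)
    # returns input_ids of shape (1, 8): "5", <bos>, five characters, <eos>.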
    def _build_vocab(self, corpus):
        vocab = {"<pad>": 0}
        # Reserve ids 1-7 for the verse-length control tokens "3" through "9".
        for verse_length in range(3, 10):
            vocab[str(verse_length)] = len(vocab)
        # Add every distinct character in the corpus, in order of first appearance.
        for doc in corpus:
            for char in self._tokenize(doc):
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab
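
    # For example, _build_vocab(["天地"]) (a hypothetical one-line corpus) produces
    # {"<pad>": 0, "3": 1, "4": 2, "5": 3, "6": 4, "7": 5, "8": 6, "9": 7,
    #  "天": 8, "地": 9, "<bos>": 10, "<eos>": 11}.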
    def decode(self, token_ids):
        chars = [self.id2vocab[token_id] for token_id in token_ids.flatten().tolist()]
        # Drop special tokens before joining back into a string.
        filtered_chars = [char for char in chars if char not in ("<bos>", "<eos>", "<pad>")]
        return "".join(filtered_chars)
    def save(self, filepath):
        with open(filepath, "w") as f:
            json.dump(self.vocab, f)

    @classmethod
    def load(cls, filepath):
        with open(filepath) as f:
            vocab = json.load(f)
        return cls(vocab=vocab)
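

if __name__ == "__main__":
    # Minimal smoke test; the two-line corpus is an illustrative stand-in,
    # not data shipped with this repo.
    corpus = ["床前明月光", "疑是地上霜"]
    tokenizer = CharTokenizer(corpus=corpus)
    encoded = tokenizer("5", text=corpus[0], add_eos_token=True)
    print(encoded["input_ids"].shape)  # torch.Size([1, 8])
    print(tokenizer.decode(encoded["input_ids"]))  # 5床前明月光

    # Round-trip the vocabulary through save/load.
    tokenizer.save("vocab.json")
    restored = CharTokenizer.load("vocab.json")
    assert restored.vocab == tokenizer.vocab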