import json

import torch
|
|
class CharTokenizer:
    def __init__(self, corpus=None, vocab=None):
        # Reuse an existing vocabulary if one is given; otherwise build it from the corpus.
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            raise ValueError("Either corpus or vocab has to be supplied")
        # Inverse mapping: position i holds the character whose token id is i.
        self.id2vocab = [char for char, _ in sorted(self.vocab.items(), key=lambda item: item[1])]
|
    def _tokenize(self, text):
        # Character-level tokenization: every character is its own token.
        return list(text)
|
    def __call__(self, prompt, text=None, add_eos_token=False):
        # Unknown characters fall back to id 0, the <pad> token.
        token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]
        if text is not None:
            # Join prompt and text with a <bos> separator.
            text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]
            token_ids = token_ids + [self.vocab["<bos>"]] + text_token_ids
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        # Add a leading batch dimension: shape (1, sequence_length).
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}
|
    def _build_vocab(self, corpus):
        vocab = {"<pad>": 0}
        # Reserve ids 1-7 for the digit prompts "3" through "9" (the verse lengths).
        for verse_length in range(3, 10):
            vocab[str(verse_length)] = len(vocab)
        # Assign the remaining ids in order of first appearance in the corpus.
        for doc in corpus:
            for char in self._tokenize(doc):
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab
|
    def decode(self, token_ids):
        # Map ids back to characters and drop the special tokens.
        chars = [self.id2vocab[token_id] for token_id in token_ids.flatten().tolist()]
        filtered_chars = [char for char in chars if char not in ["<eos>", "<bos>", "<pad>"]]
        return "".join(filtered_chars)
|
    def save(self, filepath):
        # Persist only the vocabulary; id2vocab is rebuilt on load.
        with open(filepath, "w") as f:
            json.dump(self.vocab, f)
|
    @classmethod
    def load(cls, filepath):
        with open(filepath) as f:
            vocab = json.load(f)
        return cls(vocab=vocab)
|
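# A minimal usage sketch (not part of the original listing): the toy corpus and
# file name below are made-up placeholders, purely to illustrate the round trip
# through __call__, decode, save, and load.
if __name__ == "__main__":
    corpus = ["the sun sets slowly\n", "over quiet hills\n"]  # hypothetical toy corpus
    tokenizer = CharTokenizer(corpus=corpus)

    # Encode a digit prompt plus a target text, terminated with <eos>.
    encoded = tokenizer("5", text=corpus[0], add_eos_token=True)
    print(encoded["input_ids"].shape)  # torch.Size([1, sequence_length])

    # decode() strips <bos>, <eos>, and <pad>, so prompt and text come back joined.
    print(tokenizer.decode(encoded["input_ids"]))  # "5the sun sets slowly"

    # Save the vocabulary and restore an equivalent tokenizer from disk.
    tokenizer.save("char_tokenizer.json")
    restored = CharTokenizer.load("char_tokenizer.json")
    assert restored.vocab == tokenizer.vocab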