File size: 2,109 Bytes
48d79f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch, json

class CharTokenizer:
    """Character-level tokenizer with special tokens.

    The vocabulary maps single characters — plus the special tokens
    ``<pad>``, ``<bos>``, ``<eos>`` and the verse-length marker tokens
    ``"3"`` through ``"9"`` — to integer ids. ``<pad>`` is always id 0
    and doubles as the id for unknown characters.
    """

    def __init__(self, corpus=None, vocab=None):
        """Create a tokenizer from a pre-built vocab or by scanning a corpus.

        Args:
            corpus: iterable of strings used to build a fresh vocabulary.
            vocab: existing mapping of token -> id (e.g. from ``load``).

        Raises:
            ValueError: if neither ``corpus`` nor ``vocab`` is supplied.
        """
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            # ValueError is the precise exception for bad arguments; any
            # caller catching the old broad Exception still catches it.
            raise ValueError("Either corpus or vocab has to be supplied")
        # Inverse lookup: position i holds the token whose id is i.
        # Assumes ids are contiguous 0..len(vocab)-1, which _build_vocab
        # guarantees — TODO confirm for externally supplied vocabs.
        self.id2vocab = [char for char, index in sorted(self.vocab.items(), key=lambda item: item[1])]

    def _tokenize(self, text):
        """Split *text* into a list of single-character tokens."""
        return list(text)

    def __call__(self, prompt, text=None, add_eos_token=False):
        """Encode *prompt* (and optionally *text*) into model-ready tensors.

        Unknown characters map to id 0 (``<pad>``). When *text* is given it
        is appended after a ``<bos>`` separator token; ``add_eos_token``
        appends a trailing ``<eos>``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` long tensors,
            each of shape ``(1, sequence_length)``.
        """
        token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]
        if text is not None:
            text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]
            token_ids = token_ids + [self.vocab["<bos>"]] + text_token_ids
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        # No padding is ever inserted here, so the mask is all ones.
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}

    def _build_vocab(self, corpus):
        """Build the token -> id map: <pad>, "3".."9", corpus chars, <bos>, <eos>."""
        vocab = {"<pad>": 0}
        # Reserve ids 1-7 for the verse-length control tokens "3".."9".
        for verse_length in range(3, 10):
            vocab[str(verse_length)] = len(vocab)
        # Assign ids to corpus characters in first-seen order.
        for doc in corpus:
            for char in self._tokenize(doc):
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab

    def decode(self, token_ids):
        """Decode ids back into a string, dropping all special tokens.

        Accepts a tensor of any shape or a plain iterable of ints
        (generalization; tensor input behaves exactly as before).
        """
        if torch.is_tensor(token_ids):
            ids = token_ids.flatten().tolist()
        else:
            ids = list(token_ids)
        chars = [self.id2vocab[token_id] for token_id in ids]
        filtered_chars = [char for char in chars if char not in ["<eos>", "<bos>", "<pad>"]]
        return "".join(filtered_chars)

    def save(self, filepath):
        """Serialize the vocabulary to *filepath* as JSON."""
        # Explicit UTF-8 and ensure_ascii=False keep non-ASCII characters
        # readable in the saved file instead of \uXXXX escapes.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)

    @classmethod
    def load(cls, filepath):
        """Alternate constructor: restore a tokenizer saved with ``save``."""
        with open(filepath, encoding="utf-8") as f:
            vocab = json.load(f)
        return cls(vocab=vocab)