from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace


def get_all_sentences(ds, lang: str):
    """Yield the raw `lang` sentences from a dataset of translation pairs."""
    for item in ds:
        yield item['translation'][lang]

def get_or_build_local_tokenizer(config, ds, lang: str, tokenizer_type: str, force_build: bool = False) -> Tokenizer:
    """Load the tokenizer for `lang` from disk, or train one from `ds` and save it when the file is missing or force_build is True."""
    tokenizer_path = Path(config['dataset']['tokenizer_file'].format(lang))
    if not tokenizer_path.exists() or force_build:
        if ds is None:
            raise ValueError("Cannot build tokenizer: no local tokenizer file found and the given dataset is None")
        
        if tokenizer_type == "WordLevel":
            tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        elif tokenizer_type == "BPE":
            tokenizer = Tokenizer(BPE(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = BpeTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        else:
            raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
        
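        # Train on every sentence of the requested language, then persist the tokenizer to disk.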
        tokenizer.train_from_iterator(
            get_all_sentences(ds, lang), trainer=trainer
        )
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
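

# Minimal usage sketch (assumptions: the config layout below, with a
# 'tokenizer_file' template holding a '{0}' placeholder for the language code,
# mirrors the project's real config, and the dataset is any iterable of
# {'translation': {lang: sentence}} records as expected by get_all_sentences).
if __name__ == "__main__":
    config = {'dataset': {'tokenizer_file': 'tokenizer_{0}.json'}}
    toy_ds = [
        {'translation': {'en': 'hello world', 'it': 'ciao mondo'}},
        {'translation': {'en': 'hello world', 'it': 'ciao mondo'}},
    ]
    # Trains and saves tokenizer_en.json on the first run, loads it afterwards.
    tokenizer = get_or_build_local_tokenizer(config, toy_ds, 'en', 'WordLevel')
    print(tokenizer.encode('hello world').ids)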