from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, ByteLevel


def get_all_sentences(ds, lang: str):
    """Yield the sentences of the given language from a translation dataset."""
    for item in ds:
        yield item['translation'][lang]


def get_or_build_local_tokenizer(config, ds, lang: str, tokenizer_type: str, force_build: bool = False) -> Tokenizer:
    """Load the tokenizer for `lang` from disk, or train and save a new one if it is missing or `force_build` is set."""
    tokenizer_path = Path(config['dataset']['tokenizer_file'].format(lang))
    if not tokenizer_path.exists() or force_build:
        if ds is None:
            raise ValueError("Cannot find local tokenizer, dataset given is None")
        # Build an untrained tokenizer of the requested type; both variants split on
        # whitespace and reserve the same special tokens.
        if tokenizer_type == "WordLevel":
            tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        elif tokenizer_type == "BPE":
            tokenizer = Tokenizer(BPE(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = BpeTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        else:
            raise ValueError("Unsupported tokenizer type")
        # Train on the raw sentences, then cache the tokenizer at the configured path.
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
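
A minimal usage sketch follows. The toy `toy_config` and `toy_ds` objects are assumptions made up for illustration; they only mimic the shapes the functions above expect (a config with a `dataset.tokenizer_file` format string, and dataset items carrying a `translation` dict keyed by language code).

# Hypothetical config and in-memory dataset, shaped like what the helpers expect.
toy_config = {'dataset': {'tokenizer_file': 'tokenizer_{0}.json'}}
toy_ds = [
    {'translation': {'en': 'hello world', 'it': 'ciao mondo'}},
    {'translation': {'en': 'hello world again', 'it': 'ciao mondo ancora'}},
]

# Force a fresh build, then encode a sentence with the trained tokenizer.
tokenizer = get_or_build_local_tokenizer(toy_config, toy_ds, 'en', 'WordLevel', force_build=True)
print(tokenizer.encode('hello world again').ids)
# 'again' appears only once in the toy data, so with min_frequency=2 it maps to '<unk>'.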