# mini-sun-init-opt-tf-475m / tokenizer_make.py
# @title Model Tokenizer
from transformers import GPT2TokenizerFast
import os


class MiniSunTokenizer:
    def __init__(self, vocab_file=None, merges_file=None):
        # Load from local BPE files if provided, otherwise pull the pretrained tokenizer.
        if vocab_file:
            self.tokenizer = GPT2TokenizerFast(vocab_file=vocab_file, merges_file=merges_file)
        else:
            self.tokenizer = GPT2TokenizerFast.from_pretrained('finnstrom3693/opt-125m-lss-en')

        # Define special tokens for OPT, falling back to generic placeholders when unset.
        self.pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else '[PAD]'
        self.unk_token = self.tokenizer.unk_token if self.tokenizer.unk_token else '[UNK]'
        self.cls_token = self.tokenizer.bos_token if self.tokenizer.bos_token else '[CLS]'
        self.eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else '[EOS]'
        self.mask_token = self.tokenizer.mask_token if self.tokenizer.mask_token else '[MASK]'
        self.sep_token = self.tokenizer.sep_token if self.tokenizer.sep_token else '[SEP]'

        # GPT-2-style tokenizers ship without a pad token; register one so that
        # encode_plus(padding='max_length') does not raise.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': self.pad_token})
    def encode(self, text, max_length=512, padding=True, truncation=True):
        # Dispatch on input type: a list of strings is encoded as a batch.
        if isinstance(text, list):
            return self._encode_batch(text, max_length, padding, truncation)
        else:
            return self._encode_single(text, max_length, padding, truncation)

    def _encode_single(self, text, max_length=512, padding=True, truncation=True):
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length' if padding else False,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors='np'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask']
        }
    def _encode_batch(self, texts, max_length=512, padding=True, truncation=True):
        encoded_batch = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length' if padding else False,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors='np'
        )
        return {
            'input_ids': encoded_batch['input_ids'],
            'attention_mask': encoded_batch['attention_mask']
        }
    def decode(self, token_ids):
        # Expects a flat (1-D) sequence of token ids; special tokens are stripped.
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def save_pretrained(self, save_directory):
        # Write vocab.json, merges.txt, and the tokenizer config to the given directory.
        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save_pretrained(save_directory)

    def __call__(self, text, *args, **kwargs):
        # Allow the wrapper to be called like a Hugging Face tokenizer.
        return self.encode(text, *args, **kwargs)

# Example usage of the tokenizer
tokenizer = MiniSunTokenizer()

# Offline usage with local BPE files:
# tokenizer = MiniSunTokenizer(vocab_file='vocab.json', merges_file='merges.txt')
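
# A minimal smoke test, assuming the 'finnstrom3693/opt-125m-lss-en' checkpoint is
# reachable; the sample sentences, max_length, and local path below are illustrative
# placeholders, not values from the repo.
sample = tokenizer.encode("The sun is a miniature star.", max_length=16)
print(sample['input_ids'].shape, sample['attention_mask'].shape)  # (1, 16) each

batch = tokenizer.encode(["First example sentence.", "A second, slightly longer example."], max_length=16)
print(batch['input_ids'].shape)  # (2, 16)

# Round-trip the first sequence back to text; padding and other special tokens are skipped.
print(tokenizer.decode(sample['input_ids'][0]))

# Persist the tokenizer files (hypothetical local path) and reload them offline.
tokenizer.save_pretrained('./minisun_tokenizer')
# offline_tokenizer = MiniSunTokenizer(vocab_file='./minisun_tokenizer/vocab.json',
#                                      merges_file='./minisun_tokenizer/merges.txt')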