import os

from transformers import GPT2TokenizerFast


class MiniSunTokenizer:
    """Thin wrapper around GPT2TokenizerFast with conventional special-token fallbacks."""

    def __init__(self, vocab_file=None, merges_file=None):
        if vocab_file:
            self.tokenizer = GPT2TokenizerFast(vocab_file=vocab_file, merges_file=merges_file)
        else:
            self.tokenizer = GPT2TokenizerFast.from_pretrained('finnstrom3693/opt-125m-lss-en')

        # Fall back to conventional special-token strings when the underlying
        # tokenizer does not define them.
        self.pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else '[PAD]'
        self.unk_token = self.tokenizer.unk_token if self.tokenizer.unk_token else '[UNK]'
        self.cls_token = self.tokenizer.bos_token if self.tokenizer.bos_token else '[CLS]'
        self.eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else '[EOS]'
        self.mask_token = self.tokenizer.mask_token if self.tokenizer.mask_token else '[MASK]'
        self.sep_token = self.tokenizer.sep_token if self.tokenizer.sep_token else '[SEP]'

        # GPT-2 tokenizers ship without a pad token; register the fallback so that
        # padding='max_length' in encode_plus/batch_encode_plus does not raise.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': self.pad_token})

    def encode(self, text, max_length=512, padding=True, truncation=True):
        """Encode a single string or a list of strings into input ids and attention masks."""
        if isinstance(text, list):
            return self._encode_batch(text, max_length, padding, truncation)
        else:
            return self._encode_single(text, max_length, padding, truncation)

    def _encode_single(self, text, max_length=512, padding=True, truncation=True):
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length' if padding else False,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors='np'
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask']
        }

    def _encode_batch(self, texts, max_length=512, padding=True, truncation=True):
        encoded_batch = self.tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length' if padding else False,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors='np'
        )
        return {
            'input_ids': encoded_batch['input_ids'],
            'attention_mask': encoded_batch['attention_mask']
        }

    def decode(self, token_ids):
        """Decode token ids back to text, dropping special tokens (including padding)."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save_pretrained(save_directory)

    def __call__(self, text, *args, **kwargs):
        return self.encode(text, *args, **kwargs)


tokenizer = MiniSunTokenizer()
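
# Minimal usage sketch: encodes one sentence and round-trips it through decode.
# The sample text and max_length value are illustrative only, and this assumes
# the default checkpoint above can be downloaded.
if __name__ == "__main__":
    encoded = tokenizer("Hello, world!", max_length=16)
    print(encoded['input_ids'].shape)       # (1, 16): padded/truncated to max_length
    print(encoded['attention_mask'].shape)  # (1, 16): 1 for real tokens, 0 for padding
    print(tokenizer.decode(encoded['input_ids'][0]))  # special tokens stripped on decode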