|
from typing import Tuple |
|
|
|
import torch |
|
from transformer import get_model, Transformer |
|
from config import load_config, get_weights_file_path |
|
from train import get_local_dataset_tokenizer |
|
from tokenizer import get_or_build_local_tokenizer |
|
|
|
from tokenizers import Tokenizer |
|
|
|
|
|
def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Load a pretrained Transformer and its source/target tokenizers.

    Args:
        config: Nested configuration mapping. Must provide
            ``config['model']['preload']`` (checkpoint tag) and the
            ``config['dataset']`` src/tgt language and tokenizer entries.
        device: Device to place the model on and to map the checkpoint
            to when deserializing (defaults to CPU).

    Returns:
        A ``(model, src_tokenizer, tgt_tokenizer)`` tuple.

    Raises:
        ValueError: If ``config['model']['preload']`` is None, since there
            is no checkpoint to restore.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: tokenizers are expected to exist on disk already, so no
    # dataset is needed to (re)train them.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )

    # Vocab sizes determine the embedding / projection layer dimensions.
    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    # map_location lets a checkpoint saved on GPU be restored on CPU (and
    # vice versa) without device-mismatch errors.
    model_filename = get_weights_file_path(config, config['model']['preload'])
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Plain string: the original used an f-string with no placeholders (F541).
    print('Finish loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)
|
|
|
if __name__ == '__main__':
    # Smoke-test entry point: restore the final model and both tokenizers
    # on the default device (CPU).
    config = load_config(file_name='config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config=config)
|
|