# homemade_lo_vi/load_model.py
from typing import Tuple

import torch
from tokenizers import Tokenizer

from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from tokenizer import get_or_build_local_tokenizer


def load_model_tokenizer(
    config,
    device: torch.device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Loads a local Transformer model and the source/target tokenizers
    specified by the given config.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')
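
    # No dataset is passed (ds=None), so get_or_build_local_tokenizer is
    # presumably expected to load already-built tokenizer files from disk
    # rather than train new ones.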
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )
    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)
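
    # The checkpoint written by the training script is expected to be a dict
    # with (at least) a 'model_state_dict' entry.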
    model_filename = get_weights_file_path(config, config['model']['preload'])
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    print('Finished loading model and tokenizers')
    return model, src_tokenizer, tgt_tokenizer
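

# Illustrative sketch of the config keys this loader relies on. The real
# values live in config/config_final.yaml; everything shown below is a
# placeholder, not taken from that file:
#
#   model:
#     preload: '19'             # checkpoint tag passed to get_weights_file_path
#   dataset:
#     src_lang: 'lo'
#     tgt_lang: 'vi'
#     src_tokenizer: 'wordlevel'
#     tgt_tokenizer: 'wordlevel'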


if __name__ == '__main__':
    config = load_config(file_name='config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)
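
    # Optional follow-up (a minimal sketch, not part of the original script):
    # switch to inference mode and report basic statistics, using only
    # standard torch / tokenizers calls.
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Parameters: {n_params:,}')
    print(f'Source vocab size: {src_tokenizer.get_vocab_size()}')
    print(f'Target vocab size: {tgt_tokenizer.get_vocab_size()}')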