from typing import Tuple

import torch
from tokenizers import Tokenizer

from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer


def load_model_tokenizer(
    config,
    device: torch.device = torch.device('cpu'),
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """Load a pretrained Transformer and its source/target tokenizers.

    Tokenizers are loaded from disk (``ds=None`` means no dataset is
    available to train a new one, so a prebuilt tokenizer file must exist).
    The model architecture is built from the config and the tokenizer
    vocabulary sizes, then the preloaded checkpoint weights are restored.

    Args:
        config: Parsed configuration dict; must contain
            ``model.preload`` (checkpoint tag) and the ``dataset`` keys
            ``src_lang`` / ``tgt_lang`` / ``src_tokenizer`` / ``tgt_tokenizer``.
        device: Device to map the model and checkpoint onto
            (defaults to CPU).

    Returns:
        Tuple of ``(model, src_tokenizer, tgt_tokenizer)``.

    Raises:
        ValueError: If ``config['model']['preload']`` is ``None``.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: tokenizers must already exist on disk; nothing to build from.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer'],
    )

    # Vocab sizes determine the embedding / projection layer dimensions.
    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
    model_filename = get_weights_file_path(config, config['model']['preload'])
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    print('Finished loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)


if __name__ == '__main__':
    config = load_config(file_name='config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)