|
from typing import Tuple |
|
|
|
import torch |
|
from transformer import get_model, Transformer |
|
from config import load_config, get_weights_file_path |
|
from train import get_local_dataset_tokenizer |
|
from tokenizer import get_or_build_local_tokenizer |
|
|
|
from tokenizers import Tokenizer |
|
|
|
|
|
def load_train_data_and_save_model(config, model_name: str) -> None:
    """
    Load a full training checkpoint (model, optimizer, scheduler, ...) and
    re-save ONLY the model weights under `model_name`.

    Args:
        config: project configuration dict (must have `model.preload` set to
            the checkpoint tag to load from).
        model_name: tag used by `get_weights_file_path` to build the output
            filename for the weights-only file.

    Raises:
        ValueError: if `config['model']['preload']` is falsy.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # Dataloaders are not needed here; only the tokenizers' vocab sizes are
    # required to rebuild the model architecture before loading weights.
    _, _, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

    # Raise (not assert) so validation survives `python -O`, consistent with
    # load_model_tokenizer below.
    if not config['model']['preload']:
        raise ValueError('where to preload model.')

    model_load_filename = get_weights_file_path(config, config['model']['preload'])
    print(f'Preloading model from train data in {model_load_filename}')
    state = torch.load(model_load_filename, map_location=device)

    # The training checkpoint stores optimizer/scheduler state too; keep only
    # the model weights.
    model.load_state_dict(state['model_state_dict'])

    model_save_filename = get_weights_file_path(config, model_name)
    torch.save(model.state_dict(), model_save_filename)
    print(f'Model saved at {model_save_filename}')
|
|
|
def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
    logs: bool = True,
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Load a local model and its source/target tokenizers from a given config.

    Args:
        config: project configuration dict; `model.preload` must name the
            checkpoint tag to load.
        device: torch device to map the weights onto (defaults to CPU).
        logs: when True, print a message after loading completes.

    Returns:
        Tuple of (model, src_tokenizer, tgt_tokenizer).

    Raises:
        ValueError: if `config['model']['preload']` is None.
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    # ds=None: tokenizers must already exist on disk; we never train them here.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer']
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer']
    )

    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    model_filename = get_weights_file_path(config, config['model']['preload'])
    model.load_state_dict(
        torch.load(model_filename, map_location=device)
    )
    # Bug fix: `logs` was previously ignored and the message always printed.
    if logs:
        print('Finish loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)
|
|
|
if __name__ == '__main__':
    # Script entry point: extract weights-only file from the 'huge' checkpoint.
    load_train_data_and_save_model(
        load_config(file_name='config_huge.yaml'),
        'huge',
    )
|
|