import torch
import torch.nn as nn


# Save and Load Functions.

def save_checkpoint(save_path, model, valid_loss):
    """Save a model's parameters together with its validation loss.

    Args:
        save_path: Destination file path. If ``None``, nothing is saved.
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        valid_loss: Validation loss stored under ``'valid_loss'``.
    """
    if save_path is None:
        return

    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model has been saved successfully to '{save_path}'")


def load_checkpoint(load_path, model, device):
    """Load parameters written by ``save_checkpoint`` into ``model``.

    Args:
        load_path: Checkpoint file path. If ``None``, nothing is loaded.
        model: Module that receives the stored ``'model_state_dict'``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The stored ``'valid_loss'``, or ``None`` when ``load_path`` is ``None``.
    """
    if load_path is None:
        return None

    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    # Report success only after the weights were actually applied; a
    # mismatched state dict raises inside load_state_dict above.
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return state_dict['valid_loss']


def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save training curves (train/valid losses and their global steps).

    Args:
        save_path: Destination file path. If ``None``, nothing is saved.
        train_loss_list: Training losses, stored under ``'train_loss_list'``.
        valid_loss_list: Validation losses, stored under ``'valid_loss_list'``.
        global_steps_list: Step indices, stored under ``'global_steps_list'``.
    """
    if save_path is None:
        return

    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model with metrics has been saved successfully to '{save_path}'")


def load_metrics(load_path, device):
    """Load training curves written by ``save_metrics``.

    Args:
        load_path: Metrics file path. If ``None``, nothing is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(train_loss_list, valid_loss_list, global_steps_list)``, or ``None``
        when ``load_path`` is ``None``.
    """
    if load_path is None:
        return None

    state_dict = torch.load(load_path, map_location=device)
    print(f"[LOAD] Model with metrics has been loaded successfully from '{load_path}'")
    return (
        state_dict['train_loss_list'],
        state_dict['valid_loss_list'],
        state_dict['global_steps_list'],
    )