MikkoLipsanen's picture
Upload 5 files
375fd17 verified
raw
history blame
No virus
4.9 kB
import torch
import onnx
import onnxruntime
import os
import matplotlib.pyplot as plt
import numpy as np
import random
def set_seed(random_seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        random_seed: Seed value passed to every RNG seeded below and written
            to the ``PYTHONHASHSEED`` environment variable.

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts, so setting
        it here does not change hash randomization of the current process; it
        only affects Python subprocesses launched afterwards.
    """
    # Seed each library's generator, keeping the original call order.
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(random_seed)

    # The cuDNN backend needs two extra switches for reproducible runs:
    # deterministic kernels on, benchmark-based algorithm selection off.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Fixed hash seed (see the note in the docstring about when it applies).
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    print(f"Random seed set as {random_seed}")
def save_model(model, input_size, save_model_format, save_model_path, model_name, date):
    """Save a model either as an ONNX file or as a pickled PyTorch module.

    Code modified from
    https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

    Args:
        model: The ``torch.nn.Module`` to save.
        input_size: Height and width of the dummy input used for ONNX export.
            The dummy input has shape (1, 3, input_size, input_size), so the
            model is assumed to take 3-channel square images.
        save_model_format: ``'onnx'`` exports to ONNX and verifies the export
            with ONNX Runtime; any other value saves the whole module with
            ``torch.save``.
        save_model_path: Directory the file is written into.
        model_name: Filename prefix. The file is named
            ``<model_name>_<date>.onnx`` or ``<model_name>_<date>.pth``.
        date: Filename suffix (see ``model_name``).

    Side effects:
        In the ONNX branch the caller's model is moved to the CPU in place and
        stays there; its train/eval mode is restored after export.

    Raises:
        AssertionError: If ONNX Runtime and PyTorch outputs differ beyond
            ``rtol=1e-03, atol=1e-05``.
        Errors raised by ``onnx.checker.check_model`` propagate if the
        exported graph is invalid.
    """
    if save_model_format == 'onnx':
        onnx_model_path = os.path.join(save_model_path, model_name + '_' + date + '.onnx')
        # Dummy input with the correct dimensions. The batch axis is exported
        # as dynamic (see dynamic_axes below), so batch size 1 is just a probe.
        batch_size = 1
        x = torch.randn(batch_size, 3, input_size, input_size, requires_grad=True)
        model = model.to('cpu')

        # torch.onnx.export traces in eval mode by default, so the PyTorch
        # reference output must also come from eval mode; otherwise dropout or
        # batch-norm statistics make the comparison below fail spuriously.
        # Restore the caller's mode afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                torch_out = model(x)
            # Export the model
            torch.onnx.export(model,                     # model being run
                              x,                         # model input (or a tuple for multiple inputs)
                              onnx_model_path,           # where to save the model (can be a file or file-like object)
                              export_params=True,        # store the trained parameter weights inside the model file
                              opset_version=10,          # the ONNX version to export the model to
                              do_constant_folding=True,  # whether to execute constant folding for optimization
                              input_names=['input'],     # the model's input names
                              output_names=['output'],   # the model's output names
                              dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                            'output': {0: 'batch_size'}})
        finally:
            model.train(was_training)
        print('ONNX model saved to ', onnx_model_path)

        # Structural validation of the exported graph.
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print('ONNX model checked.')

        def to_numpy(tensor):
            # detach() is a no-op for tensors that do not require grad.
            return tensor.detach().cpu().numpy()

        # Numerical validation: run the same input through ONNX Runtime.
        onnx_session = onnxruntime.InferenceSession(onnx_model_path)
        onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(x)}
        onnx_out = onnx_session.run(None, onnx_inputs)
        np.testing.assert_allclose(to_numpy(torch_out), onnx_out[0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!\n")
    else:
        # Use model_name like the ONNX branch does (previously hard-coded as
        # 'densenet_', which silently ignored the model_name argument).
        pytorch_model_path = os.path.join(save_model_path, model_name + '_' + date + '.pth')
        torch.save(model, pytorch_model_path)
        print('Pytorch model saved to ', pytorch_model_path)
def plot_metrics(hist_dict, results_folder, date):
    """Plot the training and validation results and save each plot as a JPEG.

    Args:
        hist_dict: Mapping from metric name to per-epoch values. Must contain
            'tr_loss', 'val_loss', 'tr_acc', 'val_acc', 'tr_f1', 'val_f1',
            'lr1' and 'lr2'. Every series is plotted against epochs
            1..len(hist_dict['tr_loss']).
        results_folder: Prefix for the output paths. It is concatenated
            directly with ``date`` (no separator is inserted), so it should
            normally end with a path separator.
        date: String placed in each output filename.

    Writes four files, named ``results_folder + date`` followed by
    ``_tr_val_loss.jpg``, ``_tr_val_acc.jpg``, ``_tr_val_f1.jpg`` and
    ``_learning_rate.jpg``.
    """
    epochs = range(1, len(hist_dict['tr_loss']) + 1)

    # One row per figure:
    # (green key, green label, blue key, blue label, title, y-label, file suffix)
    plot_specs = [
        ('tr_loss', 'Training loss', 'val_loss', 'Validation loss',
         'Training and Validation loss', 'Loss', '_tr_val_loss.jpg'),
        ('tr_acc', 'Training accuracy', 'val_acc', 'Validation accuracy',
         'Training and Validation accuracy', 'Accuracy', '_tr_val_acc.jpg'),
        ('tr_f1', 'Training F1 score', 'val_f1', 'Validation F1 score',
         'Training and Validation F1 score', 'F1 score', '_tr_val_f1.jpg'),
        ('lr1', 'Backbone learning rate', 'lr2', 'Classifier learning rate',
         'Learning rate', 'Learning rate', '_learning_rate.jpg'),
    ]

    for green_key, green_label, blue_key, blue_label, title, ylabel, suffix in plot_specs:
        # Use an explicit figure instead of pyplot's implicit "current figure",
        # so a figure the caller has open is neither drawn on nor closed.
        fig, ax = plt.subplots()
        try:
            ax.plot(epochs, hist_dict[green_key], 'g', label=green_label)
            ax.plot(epochs, hist_dict[blue_key], 'b', label=blue_label)
            ax.set_title(title)
            ax.set_xlabel('Epochs')
            ax.set_ylabel(ylabel)
            ax.legend()
            fig.savefig(results_folder + date + suffix, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)