import torch
import onnx
import onnxruntime
import os
import matplotlib.pyplot as plt
import numpy as np
import random


def set_seed(random_seed):
    """Function for setting random seed for the relevant libraries."""
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(random_seed)
    print(f"Random seed set as {random_seed}")
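
# Usage sketch (assumed call site, not part of this module): call set_seed()
# once at the top of the training script, before building the dataloaders and
# the model, e.g. set_seed(42), so that repeated runs are reproducible.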


def save_model(model, input_size, save_model_format, save_model_path, model_name, date):
    """Function for saving the model in .pth or .onnx format.

    Code modified from
    https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    """
    if save_model_format == 'onnx':
        onnx_model_path = os.path.join(save_model_path, model_name + '_' + date + '.onnx')

        batch_size = 1

        # Dummy input used to trace the model during export.
        x = torch.randn(batch_size, 3, input_size, input_size, requires_grad=True)
        model = model.to('cpu')
        model.eval()  # export in inference mode so batch norm / dropout behave correctly
        torch_out = model(x)

        torch.onnx.export(model,
                          x,
                          onnx_model_path,
                          export_params=True,
                          opset_version=10,
                          do_constant_folding=True,
                          input_names=['input'],
                          output_names=['output'],
                          dynamic_axes={'input': {0: 'batch_size'},
                                        'output': {0: 'batch_size'}})

        print('ONNX model saved to ', onnx_model_path)

        # Check that the exported graph is well formed.
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print('ONNX model checked.')

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

        # Run the exported model with ONNX Runtime and compare against the PyTorch output.
        onnx_session = onnxruntime.InferenceSession(onnx_model_path)

        onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(x)}
        onnx_out = onnx_session.run(None, onnx_inputs)

        np.testing.assert_allclose(to_numpy(torch_out), onnx_out[0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!\n")

    else:
        pytorch_model_path = os.path.join(save_model_path, model_name + '_' + date + '.pth')
        torch.save(model, pytorch_model_path)
        print('PyTorch model saved to ', pytorch_model_path)
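
# Usage sketch (argument values below are assumptions for illustration only):
#   save_model(trained_model, input_size=224, save_model_format='onnx',
#              save_model_path='models', model_name='densenet121',
#              date='2024-01-01')
# would write models/densenet121_2024-01-01.onnx and verify the exported graph
# against the PyTorch output with ONNX Runtime.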


def plot_metrics(hist_dict, results_folder, date):
    """Function for plotting the training and validation results."""
    epochs = range(1, len(hist_dict['tr_loss']) + 1)

    # Training vs. validation loss.
    plt.plot(epochs, hist_dict['tr_loss'], 'g', label='Training loss')
    plt.plot(epochs, hist_dict['val_loss'], 'b', label='Validation loss')
    plt.title('Training and Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(results_folder, date + '_tr_val_loss.jpg'), bbox_inches='tight')
    plt.close()

    # Training vs. validation accuracy.
    plt.plot(epochs, hist_dict['tr_acc'], 'g', label='Training accuracy')
    plt.plot(epochs, hist_dict['val_acc'], 'b', label='Validation accuracy')
    plt.title('Training and Validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join(results_folder, date + '_tr_val_acc.jpg'), bbox_inches='tight')
    plt.close()

    # Training vs. validation F1 score.
    plt.plot(epochs, hist_dict['tr_f1'], 'g', label='Training F1 score')
    plt.plot(epochs, hist_dict['val_f1'], 'b', label='Validation F1 score')
    plt.title('Training and Validation F1 score')
    plt.xlabel('Epochs')
    plt.ylabel('F1 score')
    plt.legend()
    plt.savefig(os.path.join(results_folder, date + '_tr_val_f1.jpg'), bbox_inches='tight')
    plt.close()

    # Per-epoch learning rates for the two parameter groups.
    plt.plot(epochs, hist_dict['lr1'], 'g', label='Backbone learning rate')
    plt.plot(epochs, hist_dict['lr2'], 'b', label='Classifier learning rate')
    plt.title('Learning rate')
    plt.xlabel('Epochs')
    plt.ylabel('Learning rate')
    plt.legend()
    plt.savefig(os.path.join(results_folder, date + '_learning_rate.jpg'), bbox_inches='tight')
    plt.close()
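

# Minimal smoke-test sketch (assumed entry point, not used by the training
# code): builds a dummy history dict and writes the four plots to the current
# working directory.
if __name__ == "__main__":
    set_seed(42)
    dummy_hist = {
        'tr_loss': [1.0, 0.7, 0.5], 'val_loss': [1.1, 0.8, 0.6],
        'tr_acc': [0.50, 0.70, 0.80], 'val_acc': [0.45, 0.65, 0.75],
        'tr_f1': [0.40, 0.60, 0.75], 'val_f1': [0.35, 0.55, 0.70],
        'lr1': [1e-4, 5e-5, 2.5e-5], 'lr2': [1e-3, 5e-4, 2.5e-4],
    }
    plot_metrics(dummy_hist, os.getcwd(), 'dummy')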