MikkoLipsanen's picture
Update test.py
93a3c63 verified
raw
history blame
No virus
6.16 kB
from __future__ import print_function
from __future__ import division
import torch
import onnxruntime
import numpy as np
import pandas as pd
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import seaborn as sn
import random
import time
import json
from PIL import Image
from PIL import ImageFile
from pathlib import Path
import argparse
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Command-line options for evaluating an exported .onnx classifier on two
# folders of test images (label 0 = "empty", label 1 = "ok").
parser = argparse.ArgumentParser('arguments for testing the model')
parser.add_argument('--ts_empty_folder', type=str, default="/path/to/empty/test/data/",
                    help='path to test images with empty cells (label 0)')
parser.add_argument('--ts_ok_folder', type=str, default="/path/to/non-empty/test/data/",
                    help='path to test images without empty cells (label 1)')
parser.add_argument('--results_folder', type=str, default="./results/",
                    help='Folder for saving results')
parser.add_argument('--model_path', type=str, default="/path/to/model.onnx",
                    help='path to load model file from')
# NOTE(review): --batch_size and --num_classes are not read by the code visible
# in this file (images are scored one at a time and the two class names are
# hard-coded); kept for command-line compatibility.
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch_size')
parser.add_argument('--num_classes', type=int, default=2,
                    help='number of classes for classification')
parser.add_argument('--name', type=str, default='test',
                    help='name given to result files')
parser.add_argument('--seed', type=int, default=67,
                    help='random seed for torch and Python random (default: 67)')

# Wall-clock start for the total runtime printed at the end of the script.
start = time.time()
args = parser.parse_args()
# Seed after parsing so the value can be overridden with --seed.
torch.manual_seed(args.seed)
random.seed(args.seed)

# Let PIL decode truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's decompression-bomb size check so very large images can be opened.
# NOTE(review): this removes a DoS safeguard; only run on trusted image folders.
Image.MAX_IMAGE_PIXELS = None
def get_data(empty_folder=None, ok_folder=None, pattern='*.jpg'):
    """Collect test image paths and their labels.

    Images in the "empty" folder get label 0 and images in the "ok" folder get
    label 1. Each folder's file list is sorted, because ``Path.glob`` yields
    entries in filesystem-dependent order. Sorting makes the evaluation order
    reproducible.

    Args:
        empty_folder: Folder with images that contain empty cells. Defaults to
            ``args.ts_empty_folder``.
        ok_folder: Folder with images without empty cells. Defaults to
            ``args.ts_ok_folder``.
        pattern: Non-recursive glob pattern selecting image files in each folder.

    Returns:
        ``(ts_files, ts_labels)``: a list of ``Path`` objects (all "empty"
        images first, then all "ok" images) and a float numpy array of the
        matching labels.

    Raises:
        FileNotFoundError: if either folder does not exist. Without this check,
            a wrong path would silently produce zero test images.
    """
    empty_path = Path(empty_folder if empty_folder is not None else args.ts_empty_folder)
    ok_path = Path(ok_folder if ok_folder is not None else args.ts_ok_folder)
    for folder in (empty_path, ok_path):
        if not folder.is_dir():
            raise FileNotFoundError(f'Test data folder not found: {folder}')
    empty_files = sorted(empty_path.glob(pattern))
    ok_files = sorted(ok_path.glob(pattern))
    # np.zeros / np.ones give float64 labels; kept as-is because save_preds
    # writes these values (e.g. "1.0") into its output file.
    empty_labels = np.zeros(len(empty_files))
    ok_labels = np.ones(len(ok_files))
    ts_files = empty_files + ok_files
    ts_labels = np.concatenate((empty_labels, ok_labels))
    print('Test data with empty cells: ', len(empty_files))
    print('Test data without empty cells: ', len(ok_files))
    return ts_files, ts_labels
def initialize_model(model_path=None, providers=None):
    """Create an ONNX Runtime inference session for the classifier.

    Args:
        model_path: Path to the ``.onnx`` file. Defaults to ``args.model_path``.
        providers: Execution providers in priority order. Defaults to all
            providers available in the installed onnxruntime build. An explicit
            list is needed because, since ORT 1.9, builds with more than one
            provider (e.g. GPU builds) raise if ``providers`` is omitted.

    Returns:
        ``(model, input_size)``: the ``InferenceSession`` and the square side
        length (224) that test images are resized to.
    """
    if model_path is None:
        model_path = args.model_path
    if providers is None:
        providers = onnxruntime.get_available_providers()
    model = onnxruntime.InferenceSession(model_path, providers=providers)
    # NOTE(review): fixed resolution, presumably the size the model was
    # exported with - confirm against the model's input shape.
    input_size = 224
    return model, input_size
def get_precision_recall(y_true, y_pred):
    """Print per-class precision, recall and F-score for the two classes.

    Class 0 is "empty" and class 1 is "ok". ``labels=[0, 1]`` is passed so the
    metric arrays always contain one entry per class. Without it, a run in
    which only one class occurs yields length-1 arrays, and indexing class 1
    raises IndexError. ``zero_division=0`` reports 0 for undefined metrics
    without emitting a warning.

    Args:
        y_true: array-like of true labels (0 or 1).
        y_pred: array-like of predicted labels (0 or 1).

    Returns:
        dict mapping ``'empty'`` and ``'ok'`` to ``(precision, recall, f_score)``.
    """
    precision, recall, f_score, _support = precision_recall_fscore_support(
        y_true, y_pred, labels=[0, 1], average=None, zero_division=0)
    prec_0, rec_0, F_0 = precision[0], recall[0], f_score[0]
    prec_1, rec_1, F_1 = precision[1], recall[1], f_score[1]
    print('\nPrecision for ok: %.2f' % prec_1)
    print('Recall for ok: %.2f' % rec_1)
    print('F-score for ok: %.2f' % F_1)
    print('Precision for empty: %.2f' % prec_0)
    print('Recall for empty: %.2f' % rec_0)
    print('F-score for empty: %.2f' % F_0)
    return {'empty': (prec_0, rec_0, F_0), 'ok': (prec_1, rec_1, F_1)}
def createConfusionMatrix(y_true, y_pred):
    """Plot the confusion matrix of true vs. predicted labels as a heatmap.

    Rows are true classes and columns are predicted classes, both ordered
    ``empty`` (0), ``ok`` (1). ``labels=[0, 1]`` keeps the matrix 2x2 even
    when a class never occurs. Otherwise the DataFrame built with two class
    names would not match the matrix shape.

    Args:
        y_true: array-like of true labels (0 or 1).
        y_pred: array-like of predicted labels (0 or 1).

    Returns:
        The matplotlib Figure containing the annotated seaborn heatmap.
    """
    classes = np.array(['empty', 'ok'])
    cf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    print(cf_matrix)
    df_cm = pd.DataFrame(cf_matrix, index=classes, columns=classes)
    # Start a fresh figure so the heatmap is not drawn onto a previous plot.
    plt.figure(figsize=(12, 7))
    return sn.heatmap(df_cm, annot=True).get_figure()
def save_preds(y_true, y_pred, paths, results_folder=None, name=None):
    """Save the paths and true labels of misclassified images as JSON.

    Writes ``{image_path: true_label_as_string}`` to
    ``<results_folder>/<name>_incorrect_preds`` (no file extension). The
    output folder is created if it does not exist yet.

    Args:
        y_true: array-like of true labels.
        y_pred: array-like of predicted labels, aligned with ``y_true``.
        paths: array-like of image path strings, aligned with the labels.
        results_folder: Output folder. Defaults to ``args.results_folder``.
        name: Prefix of the output file name. Defaults to ``args.name``.

    Returns:
        The dict that was written to disk.
    """
    if results_folder is None:
        results_folder = args.results_folder
    if name is None:
        name = args.name
    y_true = np.asarray(y_true)
    # asarray makes the != comparison element-wise even if lists are passed.
    wrong = y_true != np.asarray(y_pred)
    incorrectly_predicted_images = np.asarray(paths)[wrong]
    correct_labels = y_true[wrong].astype(str)
    incorrect_preds = {str(path): str(label)
                       for path, label in zip(incorrectly_predicted_images, correct_labels)}
    print(f'{len(incorrect_preds)} incorrect predictions')
    # Join with pathlib so a folder given without a trailing slash still works.
    out_dir = Path(results_folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / (name + '_incorrect_preds'), "w", encoding="utf-8") as fp:
        json.dump(incorrect_preds, fp)
    return incorrect_preds
# Build the ONNX Runtime session once for the whole run; input_size is the
# square side length every test image is resized to.
model, input_size = initialize_model()

# Per-image preprocessing: resize to input_size x input_size, then convert the
# PIL image to a tensor with torchvision's ToTensor.
# NOTE(review): there is no Normalize step - confirm this matches the
# preprocessing the model was trained/exported with.
data_transforms = transforms.Compose(
    [
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
    ]
)

print("Initializing Datasets and Dataloaders...")
ts_files, ts_labels = get_data()
def test_model(model, ts_files, ts_labels):
    """Run the ONNX classifier on every test image, one image at a time.

    Each image is converted to RGB, passed through the module-level
    ``data_transforms`` and fed to the model as a batch of one. The predicted
    class is the argmax along axis 1 of the first model output.

    Args:
        model: ``onnxruntime.InferenceSession``; its first input receives the image.
        ts_files: sequence of image paths.
        ts_labels: true labels aligned with ``ts_files``.

    Returns:
        ``(y_pred, y_true, paths)`` as numpy arrays; ``paths`` holds the image
        paths as strings.

    Raises:
        ValueError: if ``ts_files`` and ``ts_labels`` differ in length.
    """
    since = time.time()
    label_preds = []
    true_labels = []
    paths = []
    n = len(ts_files)
    if len(ts_labels) != n:
        raise ValueError(
            f'ts_files and ts_labels differ in length ({n} vs {len(ts_labels)})')
    # The session's input name does not change, so look it up once.
    input_name = model.get_inputs()[0].name
    for i, (file_path, label) in enumerate(zip(ts_files, ts_labels)):
        print(f'{i}/{n}')
        # The context manager guarantees the file handle is released, even if
        # decoding or conversion fails.
        with Image.open(file_path) as image:
            tensor = data_transforms(image.convert("RGB")).unsqueeze(0)
        # ONNX Runtime takes numpy arrays, not torch tensors.
        img = tensor.detach().cpu().numpy()
        output = model.run(None, {input_name: img})
        pred_class = np.argmax(output[0], 1).item()
        label_preds.append(pred_class)
        true_labels.append(label)
        paths.append(str(file_path))
    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    return np.array(label_preds), np.array(true_labels), np.array(paths)
# get_data already returns a numpy array; kept as a defensive conversion.
ts_labels = np.array(ts_labels)
# Test model
y_pred, y_true, paths = test_model(model, ts_files, ts_labels)
# Save information of incorrect predictions
save_preds(y_true, y_pred, paths)
# Calculate and print precision, recall and F-score metrics
get_precision_recall(y_true, y_pred)
# Confusion matrix normalized over the true labels (each row sums to 1).
# labels=[0, 1] keeps the matrix 2x2 so the two display_labels always match,
# even when one class is missing from both y_true and y_pred.
conf_matrix = ConfusionMatrixDisplay.from_predictions(
    y_true, y_pred, labels=[0, 1], normalize='true',
    display_labels=np.array(['empty', 'ok']))
# Build the output path with pathlib and ensure the folder exists.
results_dir = Path(args.results_folder)
results_dir.mkdir(parents=True, exist_ok=True)
# Save the display's own figure rather than relying on matplotlib's current figure.
conf_matrix.figure_.savefig(results_dir / (args.name + '_conf_matrix.jpg'),
                            bbox_inches='tight')
end = time.time()
time_in_mins = (end - start) / 60
print('Time: %.2f minutes' % time_in_mins)