MikkoLipsanen commited on
Commit
375fd17
1 Parent(s): 0976156

Upload 5 files

Browse files
Files changed (5) hide show
  1. augment.py +89 -0
  2. requirements.txt +10 -0
  3. test.py +192 -0
  4. train.py +332 -0
  5. utils.py +107 -0
augment.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ import random
3
+ import numpy as np
4
+
5
+ class RandAug:
6
+ """Randomly chosen image augmentations."""
7
+
8
+ def __init__(self, img_size, choice=None):
9
+ # Augmentation options
10
+ self.trans = ['identity', 'rotate', 'color', 'sharpness', 'blur', 'padding' ,'perspective']
11
+ self.img_size = img_size
12
+ self.choice = choice
13
+
14
+ def __call__(self, img):
15
+ if self.choice == None:
16
+ # Weights set 40% probability for the 'identity' augmentation choice
17
+ self.choice = random.choices(self.trans, weights=(40, 10, 10, 10, 10, 10, 10))[0]
18
+
19
+ if self.choice == 'identity':
20
+ trans = transforms.Compose([
21
+ transforms.Resize((self.img_size,self.img_size)),
22
+ transforms.ToTensor()
23
+ ])
24
+ img = trans(img)
25
+
26
+ elif self.choice == 'rotate':
27
+ degrees = random.uniform(0, 180)
28
+ rand_fill = random.choice([0,1])
29
+ trans = transforms.Compose([
30
+ transforms.Resize((self.img_size,self.img_size)),
31
+ transforms.ToTensor(),
32
+ transforms.RandomRotation(degrees, expand=True, fill=rand_fill),
33
+ transforms.Resize((self.img_size,self.img_size))
34
+ ])
35
+ img = trans(img)
36
+
37
+ elif self.choice == 'color':
38
+ rand_brightness = random.uniform(0, 0.3)
39
+ rand_hue = random.uniform(0, 0.5)
40
+ rand_contrast = random.uniform(0, 0.5)
41
+ rand_saturation = random.uniform(0, 0.5)
42
+ trans = transforms.Compose([
43
+ transforms.Resize((self.img_size,self.img_size)),
44
+ transforms.ToTensor(),
45
+ transforms.ColorJitter(brightness=rand_brightness, contrast=rand_contrast, saturation=rand_saturation, hue=rand_hue)
46
+ ])
47
+ img = trans(img)
48
+
49
+ elif self.choice=='sharpness':
50
+ sharpness = 1+(np.random.exponential()/2)
51
+ trans = transforms.Compose([
52
+ transforms.Resize((self.img_size,self.img_size)),
53
+ transforms.ToTensor(),
54
+ transforms.RandomAdjustSharpness(sharpness, p=1)
55
+ ])
56
+ img = trans(img)
57
+
58
+ elif self.choice=='blur':
59
+ kernel = random.choice([1,3,5])
60
+ trans = transforms.Compose([
61
+ transforms.Resize((self.img_size,self.img_size)),
62
+ transforms.ToTensor(),
63
+ transforms.GaussianBlur(kernel, sigma=(0.1, 2.0))
64
+ ])
65
+ img = trans(img)
66
+
67
+ elif self.choice=='padding':
68
+ pad = random.choice([3,10,25])
69
+ rand_fill = random.choice([0,1])
70
+ trans = transforms.Compose([
71
+ transforms.Resize((self.img_size,self.img_size)),
72
+ transforms.ToTensor(),
73
+ transforms.Pad(pad, fill=rand_fill, padding_mode='constant'),
74
+ transforms.Resize((self.img_size,self.img_size))
75
+ ])
76
+ img = trans(img)
77
+
78
+ elif self.choice=='perspective':
79
+ scale = random.uniform(0.1, 0.5)
80
+ rand_fill = random.choice([0,1])
81
+ trans = transforms.Compose([
82
+ transforms.Resize((self.img_size,self.img_size)),
83
+ transforms.ToTensor(),
84
+ transforms.RandomPerspective(distortion_scale=scale, p=1.0, fill=rand_fill),
85
+ transforms.Resize((self.img_size,self.img_size))
86
+ ])
87
+ img = trans(img)
88
+
89
+ return img
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu116
2
+ torch==1.12.1+cu116
3
+ torchvision==0.13.1+cu116
4
+ scikit-learn==1.0.2
5
+ numpy==1.21.6
6
+ pillow==9.3.0
7
+ matplotlib==3.5.3
8
+ onnx==1.13.0
9
+ onnxruntime==1.13.1
10
+ tqdm==4.64.1
test.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from __future__ import division
3
+ import torch
4
+ import onnxruntime
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torchvision
8
+ from torchvision import transforms
9
+ import matplotlib.pyplot as plt
10
+ from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
11
+ import seaborn as sn
12
+ import random
13
+ import time
14
+ import json
15
+ from PIL import Image
16
+ from PIL import ImageFile
17
+ from pathlib import Path
18
+ import argparse
19
+ print("PyTorch Version: ",torch.__version__)
20
+ print("Torchvision Version: ",torchvision.__version__)
21
+
22
+ parser = argparse.ArgumentParser('arguments for testing the model')
23
+
24
+ parser.add_argument('--ts_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/test/",
25
+ help='path to test data')
26
+ parser.add_argument('--ts_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/test/",
27
+ help='path to test data')
28
+ parser.add_argument('--results_folder', type=str, default="./results/aug_28022024/",
29
+ help='Folder for saving results')
30
+ parser.add_argument('--model_path', type=str, default="/koodit/table_segmentation/empty_cell_detection/train/models/aug_b32_lr0001_28022024.onnx",
31
+ help='path to load model file from')
32
+ parser.add_argument('--batch_size', type=int, default=16,
33
+ help='batch_size')
34
+ parser.add_argument('--num_classes', type=int, default=2,
35
+ help='number of classes for classification')
36
+ parser.add_argument('--name', type=str, default='empty_cell_augment_28022024',
37
+ help='name given to result files')
38
+
39
+ start = time.time()
40
+
41
+ # nohup python test.py > logs/aug_test_28022024.txt 2>&1 &
42
+ # echo $! > output/save_pid.txt
43
+
44
+ torch.manual_seed(67)
45
+ random.seed(67)
46
+
47
+ args = parser.parse_args()
48
+
49
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
50
+ Image.MAX_IMAGE_PIXELS = None
51
+
52
+ # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
53
+
54
+
55
+ def get_data():
56
+ empty_path = Path(args.ts_empty_folder)
57
+ ok_path = Path(args.ts_ok_folder)
58
+
59
+ empty_files = list(empty_path.glob('*.jpg'))
60
+ ok_files = list(ok_path.glob('*.jpg'))
61
+
62
+ empty_labels = np.zeros(len(empty_files))
63
+ ok_labels = np.ones(len(ok_files))
64
+
65
+ #ts_data_files = ts_data_files[:20]
66
+ #ts_data_labels = ts_data_labels[:20]
67
+ #ts_ok_files = ts_ok_files[:20]
68
+ #ts_ok_labels = ts_ok_labels[:20]
69
+
70
+ ts_files = empty_files + ok_files
71
+ ts_labels = np.concatenate((empty_labels, ok_labels))
72
+
73
+ print('Test data with empty cells: ', len(empty_files))
74
+ print('Test data without empty cells: ', len(ok_files))
75
+
76
+ return ts_files, ts_labels
77
+
78
+
79
+ def initialize_model():
80
+ model = onnxruntime.InferenceSession(args.model_path)
81
+ input_size = 224
82
+ return model, input_size
83
+
84
+ # Function for getting precision, recall and F-score metrics
85
+ def get_precision_recall(y_true, y_pred):
86
+ precision_recall_fscore = precision_recall_fscore_support(y_true, y_pred, average=None)
87
+
88
+ prec_0 = precision_recall_fscore[0][0]
89
+ rec_0 = precision_recall_fscore[1][0]
90
+ F_0 = precision_recall_fscore[2][0]
91
+
92
+ prec_1 = precision_recall_fscore[0][1]
93
+ rec_1 = precision_recall_fscore[1][1]
94
+ F_1 = precision_recall_fscore[2][1]
95
+
96
+ print('\nPrecision for ok: %.2f'%prec_1)
97
+ print('Recall for ok: %.2f'%rec_1)
98
+ print('F-score for ok: %.2f'%F_1)
99
+
100
+ print('Precision for empty: %.2f'%prec_0 )
101
+ print('Recall for empty: %.2f'%rec_0)
102
+ print('F-score for empty: %.2f'%F_0)
103
+
104
+
105
+ def createConfusionMatrix(y_true, y_pred):
106
+ classes = np.array(['empty', 'ok'])
107
+
108
+ # Build confusion matrix
109
+ cf_matrix = confusion_matrix(y_true, y_pred)
110
+ print(cf_matrix)
111
+ df_cm = pd.DataFrame(cf_matrix, index=classes,
112
+ columns=classes)
113
+ plt.figure(figsize=(12, 7))
114
+ return sn.heatmap(df_cm, annot=True).get_figure()
115
+
116
+ def save_preds(y_true, y_pred, paths):
117
+ # Identifies images that were not classified correctly
118
+ incorrect_indices = np.where(y_true != y_pred)
119
+ incorrectly_predicted_images = paths[incorrect_indices]
120
+ correct_labels = y_true[incorrect_indices].astype(str)
121
+ incorrect_preds = dict(zip(incorrectly_predicted_images, correct_labels))
122
+
123
+ print(f'{len(incorrect_preds)} incorrect predictions')
124
+
125
+ # Save file names and labels of incorrectly classified images
126
+ with open(args.results_folder + args.name + '_incorrect_preds', "w") as fp:
127
+ json.dump(incorrect_preds, fp)
128
+
129
+ # Initialize the model for this run
130
+ model, input_size = initialize_model()
131
+
132
+ # Print the model we just instantiated
133
+ #print(model_ft)
134
+
135
+ data_transforms = transforms.Compose([
136
+ transforms.Resize((input_size, input_size)),
137
+ transforms.ToTensor()
138
+ ])
139
+
140
+ print("Initializing Datasets and Dataloaders...")
141
+
142
+ ts_files, ts_labels = get_data()
143
+
144
+ # Function for getting model predictions on test data
145
+ def test_model(model, ts_files, ts_labels):
146
+ since = time.time()
147
+ label_preds = []
148
+ true_labels = []
149
+ paths = []
150
+ n = len(ts_files)
151
+ # Iterate over data
152
+ for i in range(n):
153
+ print(f'{i}/{n}')
154
+ image = Image.open(ts_files[i])
155
+ label = ts_labels[i]
156
+ image = data_transforms(image.convert("RGB")).unsqueeze(0)
157
+ # Transform tensor to numpy array
158
+ img = image.detach().cpu().numpy()
159
+ input = {model.get_inputs()[0].name: img}
160
+ # Run model prediction
161
+ output = model.run(None, input)
162
+ # Get predicted class
163
+ pred = np.argmax(output[0], 1)
164
+ pred_class = pred.item()
165
+ label_preds.append(pred_class)
166
+ true_labels.append(label)
167
+ paths.append(str(ts_files[i]))
168
+
169
+ time_elapsed = time.time() - since
170
+ print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
171
+
172
+ return np.array(label_preds), np.array(true_labels), np.array(paths)
173
+
174
+ ts_labels = np.array(ts_labels)
175
+
176
+ # Test model
177
+ y_pred, y_true, paths = test_model(model, ts_files, ts_labels)
178
+ # Saves information of incorrect predictions
179
+ save_preds(y_true, y_pred, paths)
180
+ # Calculates and prints precision, recall and F-score metrics
181
+ get_precision_recall(y_true, y_pred)
182
+
183
+ # Save confusion matrix to Tensorboard
184
+ #cm = createConfusionMatrix(y_true, y_pred)
185
+ #writer.add_figure("Confusion matrix", cm)
186
+ # Create and save confusion matrix of the predictions and true labels
187
+ conf_matrix = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true', display_labels=np.array(['empty', 'ok']))
188
+ plt.savefig(args.results_folder + args.name + '_conf_matrix.jpg', bbox_inches='tight')
189
+
190
+ end = time.time()
191
+ time_in_mins = (end - start) / 60
192
+ print('Time: %.2f minutes' % time_in_mins)
train.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from __future__ import division
3
+ import torch
4
+ import torchvision
5
+ import torch.nn as nn
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import models
8
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
9
+ from sklearn.utils import class_weight
10
+ from sklearn.metrics import precision_recall_fscore_support
11
+ import numpy as np
12
+ import time
13
+ import argparse
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageFile
16
+ from pathlib import Path
17
+
18
+ from augment import RandAug
19
+ import utils
20
+
21
+ print("PyTorch Version: ",torch.__version__)
22
+ print("Torchvision Version: ",torchvision.__version__)
23
+
24
+ # Much of the code is a modified version of the code available at
25
+ # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
26
+
27
+
28
+ # nohup python train.py > logs/empty_cell_aug_28032024.txt 2>&1 &
29
+ # echo $! > logs/save_pid.txt
30
+
31
+ parser = argparse.ArgumentParser('arguments for training')
32
+
33
+ parser.add_argument('--tr_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/train/",
34
+ help='path to training data with empty images')
35
+ parser.add_argument('--val_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/val/",
36
+ help='path to validation data with empty images')
37
+ parser.add_argument('--tr_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/train/",
38
+ help='path to training data with ok images')
39
+ parser.add_argument('--val_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/val/",
40
+ help='path to validation data with ok images')
41
+ parser.add_argument('--results_folder', type=str, default="results/28032024_aug/",
42
+ help='Folder for saving training results.')
43
+ parser.add_argument('--save_model_path', type=str, default="./models/",
44
+ help='Path for saving model file.')
45
+ parser.add_argument('--batch_size', type=int, default=32,
46
+ help='Batch size used for model training. ')
47
+ parser.add_argument('--lr', type=float, default=0.0001,
48
+ help='Base learning rate.')
49
+ parser.add_argument('--device', type=str, default='cpu',
50
+ help='Defines whether the model is trained using cpu or gpu.')
51
+ parser.add_argument('--num_classes', type=int, default=2,
52
+ help='Number of classes used in classification.')
53
+ parser.add_argument('--num_epochs', type=int, default=15,
54
+ help='Number of training epochs.')
55
+ parser.add_argument('--random_seed', type=int, default=8765,
56
+ help='Number used for initializing random number generation.')
57
+ parser.add_argument('--early_stop_threshold', type=int, default=3,
58
+ help='Threshold value of epochs after which training stops if validation accuracy does not improve.')
59
+ parser.add_argument('--save_model_format', type=str, default='torch',
60
+ help='Defines the format for saving the model.')
61
+ parser.add_argument('--augment_choice', type=str, default=None,
62
+ help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
63
+ parser.add_argument('--model_name', type=str, default='aug_b32_lr0001',
64
+ help='Current date.')
65
+ parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
66
+ help='Current date.')
67
+
68
+ args = parser.parse_args()
69
+
70
+ # PIL settings to avoid errors caused by truncated and large images
71
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
72
+ Image.MAX_IMAGE_PIXELS = None
73
+
74
+ # List for saving the names of damaged images
75
+ damaged_images = []
76
+
77
+ def get_datapaths():
78
+ """Function for loading train and validation data."""
79
+ tr_empty_files = list(Path(args.tr_empty_folder).glob('*'))
80
+ tr_ok_files = list(Path(args.tr_ok_folder).glob('*'))
81
+ val_empty_files = list(Path(args.val_empty_folder).glob('*'))
82
+ val_ok_files = list(Path(args.val_ok_folder).glob('*'))
83
+ # Create labels for train and validation data
84
+ tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
85
+ val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))
86
+ # Combine faulty and non-faulty images
87
+ tr_files = tr_empty_files + tr_ok_files
88
+ val_files = val_empty_files + val_ok_files
89
+
90
+ print('\nTraining data with empty cells: ', len(tr_empty_files))
91
+ print('Training data without empty cells: ', len(tr_ok_files))
92
+
93
+ print('Validation data with empty cells: ', len(val_empty_files))
94
+ print('Validation data without empty cells: ', len(val_ok_files))
95
+
96
+ data_dict = {'tr_data': tr_files, 'tr_labels': tr_labels,
97
+ 'val_data': val_files, 'val_labels': val_labels}
98
+
99
+ return data_dict
100
+
101
+ class ImageDataset(Dataset):
102
+ """PyTorch Dataset class is used for generating training and validation datasets."""
103
+ def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
104
+ self.img_paths = img_paths
105
+ self.img_labels = img_labels
106
+ self.transform = transform
107
+ self.target_transform = target_transform
108
+
109
+ def __len__(self):
110
+ return len(self.img_labels)
111
+
112
+ def __getitem__(self, idx):
113
+ img_path = self.img_paths[idx]
114
+ try:
115
+ image = Image.open(img_path).convert('RGB')
116
+ label = self.img_labels[idx]
117
+ except:
118
+ # Image is considered damaged if reading the image fails
119
+ damaged_images.append(img_path)
120
+ return None
121
+ if self.transform:
122
+ image = self.transform(image.convert("RGB"))
123
+ if self.target_transform:
124
+ label = self.target_transform(label)
125
+
126
+ return image, label
127
+
128
+ def initialize_model():
129
+ """Function for initializing pretrained neural network model (DenseNet121)."""
130
+ model_ft = models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1)
131
+ num_ftrs = model_ft.classifier.in_features
132
+ model_ft.classifier = nn.Linear(num_ftrs, args.num_classes)
133
+ input_size = 224
134
+
135
+ return model_ft, input_size
136
+
137
+ def collate_fn(batch):
138
+ """Helper function for creating data batches."""
139
+ batch = list(filter(lambda x: x is not None, batch))
140
+
141
+ return torch.utils.data.dataloader.default_collate(batch)
142
+
143
+ def initialize_dataloaders(data_dict, input_size):
144
+ """Function for initializing datasets and dataloaders."""
145
+ # Train and validation datasets
146
+ train_dataset = ImageDataset(img_paths=data_dict['tr_data'], img_labels=data_dict['tr_labels'], transform=RandAug(input_size, args.augment_choice))
147
+ validation_dataset = ImageDataset(img_paths=data_dict['val_data'], img_labels=data_dict['val_labels'], transform=RandAug(input_size, 'identity'))
148
+ # Train and validation dataloaders
149
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=True, num_workers=4)
150
+ validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=True, num_workers=4)
151
+
152
+ return {'train': train_dataloader, 'val': validation_dataloader}
153
+
154
+ def get_criterion(data_dict):
155
+ """Function for generating class weights and initializing the loss function."""
156
+ y = np.asarray(data_dict['tr_labels'])
157
+ # Class weights are used for compensating the unbalance
158
+ # in the number of training data from the two classes
159
+ class_weights=class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
160
+ class_weights=torch.tensor(class_weights, dtype=torch.float).to(args.device)
161
+ print('\nClass weights: ', class_weights.tolist())
162
+ # Cross Entropy Loss function
163
+ criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
164
+
165
+ return criterion
166
+
167
+ def get_optimizer(model):
168
+ """Function for initializing the optimizer."""
169
+ # Model parameters are split into two groups: parameters of the classifier
170
+ # layer and other model parameters
171
+ params_1 = [param for name, param in model.named_parameters()
172
+ if name not in ["classifier.weight", "classifier.bias"]]
173
+ params_2 = model.classifier.parameters()
174
+ # 10 x larger learning rate is used when training the parameters
175
+ # of the classification layers
176
+ params_to_update = [
177
+ {'params': params_1, 'lr': args.lr},
178
+ {'params': params_2, 'lr': args.lr * 10}
179
+ ]
180
+ # Adam optimizer
181
+ optimizer = torch.optim.Adam(params_to_update, args.lr)
182
+ # Scheduler reduces learning rate when validation accuracy does not improve for an epoch
183
+ scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)
184
+
185
+ return optimizer, scheduler
186
+
187
+ def train_model(model, dataloaders, criterion, optimizer, scheduler=None):
188
+ """Function for model training and validation."""
189
+ since = time.time()
190
+ # Lists for saving train and validation metrics for each epoch
191
+ tr_loss_history = []
192
+ tr_acc_history = []
193
+ tr_f1_history = []
194
+ val_loss_history = []
195
+ val_acc_history = []
196
+ val_f1_history = []
197
+ # Lists for saving learning rates for the 2 parameter groups
198
+ lr1_history = []
199
+ lr2_history = []
200
+
201
+ # Best F1 value and best epoch are saved in variables
202
+ best_f1 = 0
203
+ best_epoch = 0
204
+ early_stop = False
205
+
206
+ # Train / validation loop
207
+ for epoch in tqdm(range(args.num_epochs)):
208
+ # Save learning rates for the epoch
209
+ lr1_history.append(optimizer.param_groups[0]["lr"])
210
+ lr2_history.append(optimizer.param_groups[1]["lr"])
211
+
212
+ print('Epoch {}/{}'.format(epoch+1, args.num_epochs))
213
+ print('-' * 10)
214
+
215
+ # Each epoch has a training and validation phase
216
+ for phase in ['train', 'val']:
217
+ if phase == 'train':
218
+ model.train() # Set model to training mode
219
+ else:
220
+ model.eval() # Set model to evaluate mode
221
+
222
+ running_loss = 0.0
223
+ running_corrects = 0
224
+ running_f1 = 0.0
225
+
226
+ # Iterate over data in batch
227
+ for inputs, labels in dataloaders[phase]:
228
+ if dataloaders[phase] is None:
229
+ continue
230
+ else:
231
+ inputs = inputs.to(args.device)
232
+ labels = labels.long().to(args.device)
233
+
234
+ # Zero the parameter gradients
235
+ optimizer.zero_grad()
236
+
237
+ # Track history only in training phase
238
+ with torch.set_grad_enabled(phase == 'train'):
239
+ # Get model outputs and calculate loss
240
+ outputs = model(inputs)
241
+ loss = criterion(outputs, labels)
242
+ # Model predictions of the image labels for the batch
243
+ _, preds = torch.max(outputs, 1)
244
+
245
+ # Backward + optimize only if in training phase
246
+ if phase == 'train':
247
+ loss.backward()
248
+ optimizer.step()
249
+
250
+ # Get weighted F1 score for the results
251
+ precision_recall_fscore = precision_recall_fscore_support(labels.data.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='weighted', zero_division=0)
252
+ f1_score = precision_recall_fscore[2]
253
+
254
+ # update statistics
255
+ running_loss += loss.item() * inputs.size(0)
256
+ running_corrects += torch.sum(preds == labels.data).cpu()
257
+ running_f1 += f1_score
258
+
259
+ # Calculate loss, accuracy and F1 score for the epoch
260
+ epoch_loss = running_loss / len(dataloaders[phase].dataset)
261
+ epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
262
+ epoch_f1 = running_f1 / len(dataloaders[phase])
263
+
264
+ print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(epoch+1, phase, epoch_loss, epoch_acc, epoch_f1))
265
+
266
+ # Validation step
267
+ if phase == 'val':
268
+ val_acc_history.append(epoch_acc)
269
+ val_loss_history.append(epoch_loss)
270
+ val_f1_history.append(epoch_f1)
271
+ if epoch_f1 > best_f1:
272
+ print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(epoch_f1, best_f1))
273
+ # Model with best F1 score is saved
274
+ utils.save_model(model, 224, args.save_model_format, args.save_model_path, args.model_name, args.date)
275
+ model = model.to(args.device)
276
+ best_f1 = epoch_f1
277
+ best_epoch = epoch
278
+ elif epoch - best_epoch > args.early_stop_threshold:
279
+ # terminates the training loop if validation accuracy has not improved
280
+ print("Early stopped training at epoch %d" % epoch)
281
+ # Set early stopping condition
282
+ early_stop = True
283
+ break
284
+ elif phase == 'train':
285
+ tr_acc_history.append(epoch_acc)
286
+ tr_loss_history.append(epoch_loss)
287
+ tr_f1_history.append(epoch_f1)
288
+
289
+ # Break outer loop if early stopping condition is activated
290
+ if early_stop:
291
+ break
292
+ # Take scheduler step
293
+ if scheduler:
294
+ scheduler.step(val_f1_history[-1])
295
+
296
+ time_elapsed = time.time() - since
297
+ print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
298
+ print('Best validation F1 score: {:.4f}'.format(best_f1))
299
+ # Returns model with the weights from the best epoch (based on validation accuracy)
300
+ hist_dict = {'tr_acc': tr_acc_history,
301
+ 'val_acc': val_acc_history,
302
+ 'val_loss': val_loss_history,
303
+ 'val_f1': val_f1_history,
304
+ 'tr_loss': tr_loss_history,
305
+ 'tr_f1': tr_f1_history,
306
+ 'lr1': lr1_history,
307
+ 'lr2': lr2_history}
308
+
309
+ return hist_dict
310
+
311
+ def main():
312
+ # Set random seed(s)
313
+ utils.set_seed(args.random_seed)
314
+ # Load image paths and labels
315
+ data_dict = get_datapaths()
316
+ # Initialize the model
317
+ model, input_size = initialize_model()
318
+ # Print the model architecture
319
+ #print(model_ft)
320
+ # Send the model to GPU (if available)
321
+ model = model.to(args.device)
322
+ print("\nInitializing Datasets and Dataloaders...")
323
+ dataloaders_dict = initialize_dataloaders(data_dict, input_size)
324
+ criterion = get_criterion(data_dict)
325
+ optimizer, scheduler = get_optimizer(model)
326
+ # Train and evaluate model
327
+ hist_dict = train_model(model, dataloaders_dict, criterion, optimizer, scheduler)
328
+ print('Damaged images: ', damaged_images)
329
+ utils.plot_metrics(hist_dict, args.results_folder, args.date)
330
+
331
+ if __name__ == '__main__':
332
+ main()
utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import onnx
3
+ import onnxruntime
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import random
8
+
9
+ def set_seed(random_seed):
10
+ """Function for setting random seed for the relevant libraries."""
11
+ np.random.seed(random_seed)
12
+ random.seed(random_seed)
13
+ torch.manual_seed(random_seed)
14
+ torch.cuda.manual_seed(random_seed)
15
+ # When running on the CuDNN backend, two further options must be set
16
+ torch.backends.cudnn.deterministic = True
17
+ torch.backends.cudnn.benchmark = False
18
+ # Set a fixed value for the hash seed
19
+ os.environ["PYTHONHASHSEED"] = str(random_seed)
20
+ print(f"Random seed set as {random_seed}")
21
+
22
+ def save_model(model, input_size, save_model_format, save_model_path, model_name, date):
23
+ """Function for saving the model in .pth or .onnx format.
24
+ Code modified from
25
+ https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"""
26
+ if save_model_format == 'onnx':
27
+ onnx_model_path = os.path.join(save_model_path, model_name + '_' + date + '.onnx')
28
+ # Random batch size
29
+ batch_size = 1
30
+ # Random input to the model (with correct dimensions)
31
+ x = torch.randn(batch_size, 3, input_size, input_size, requires_grad=True)
32
+ model = model.to('cpu')
33
+ torch_out = model(x)
34
+
35
+ # Export the model
36
+ torch.onnx.export(model, # model being run
37
+ x, # model input (or a tuple for multiple inputs)
38
+ onnx_model_path, # where to save the model (can be a file or file-like object)
39
+ export_params=True, # store the trained parameter weights inside the model file
40
+ opset_version=10, # the ONNX version to export the model to
41
+ do_constant_folding=True, # whether to execute constant folding for optimization
42
+ input_names = ['input'], # the model's input names
43
+ output_names = ['output'], # the model's output names
44
+ dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
45
+ 'output' : {0 : 'batch_size'}})
46
+
47
+ print('ONNX model saved to ', onnx_model_path)
48
+ # Test transformed model
49
+ onnx_model = onnx.load(onnx_model_path)
50
+ onnx.checker.check_model(onnx_model)
51
+ print('ONNX model checked.')
52
+
53
+ def to_numpy(tensor):
54
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
55
+
56
+ onnx_session = onnxruntime.InferenceSession(onnx_model_path)
57
+ # compute ONNX Runtime output prediction
58
+ onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(x)}
59
+ onnx_out = onnx_session.run(None, onnx_inputs)
60
+ # compare ONNX Runtime and PyTorch results
61
+ np.testing.assert_allclose(to_numpy(torch_out), onnx_out[0], rtol=1e-03, atol=1e-05)
62
+ print("Exported model has been tested with ONNXRuntime, and the result looks good!\n")
63
+
64
+ else:
65
+ pytorch_model_path = os.path.join(save_model_path, 'densenet_' + date + '.pth')
66
+ torch.save(model, pytorch_model_path)
67
+ print('Pytorch model saved to ', pytorch_model_path)
68
+
69
+
70
+ def plot_metrics(hist_dict, results_folder, date):
71
+ """Function for plotting the training and validation results."""
72
+ epochs = range(1, len(hist_dict['tr_loss'])+1)
73
+ plt.plot(epochs, hist_dict['tr_loss'], 'g', label='Training loss')
74
+ plt.plot(epochs, hist_dict['val_loss'], 'b', label='Validation loss')
75
+ plt.title('Training and Validation loss')
76
+ plt.xlabel('Epochs')
77
+ plt.ylabel('Loss')
78
+ plt.legend()
79
+ plt.savefig(results_folder + date + '_tr_val_loss.jpg', bbox_inches='tight')
80
+ plt.close()
81
+
82
+ plt.plot(epochs, hist_dict['tr_acc'], 'g', label='Training accuracy')
83
+ plt.plot(epochs, hist_dict['val_acc'], 'b', label='Validation accuracy')
84
+ plt.title('Training and Validation accuracy')
85
+ plt.xlabel('Epochs')
86
+ plt.ylabel('Accuracy')
87
+ plt.legend()
88
+ plt.savefig(results_folder + date + '_tr_val_acc.jpg', bbox_inches='tight')
89
+ plt.close()
90
+
91
+ plt.plot(epochs, hist_dict['tr_f1'], 'g', label='Training F1 score')
92
+ plt.plot(epochs, hist_dict['val_f1'], 'b', label='Validation F1 score')
93
+ plt.title('Training and Validation F1 score')
94
+ plt.xlabel('Epochs')
95
+ plt.ylabel('F1 score')
96
+ plt.legend()
97
+ plt.savefig(results_folder + date + '_tr_val_f1.jpg', bbox_inches='tight')
98
+ plt.close()
99
+
100
+ plt.plot(epochs, hist_dict['lr1'], 'g', label='Backbone learning rate')
101
+ plt.plot(epochs, hist_dict['lr2'], 'b', label='Classifier learning rate')
102
+ plt.title('Learning rate')
103
+ plt.xlabel('Epochs')
104
+ plt.ylabel('Learning rate')
105
+ plt.legend()
106
+ plt.savefig(results_folder + date + '_learning_rate.jpg', bbox_inches='tight')
107
+ plt.close()