from __future__ import print_function
from __future__ import division
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.utils import class_weight
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import time
import argparse
from tqdm import tqdm
from PIL import Image, ImageFile
from pathlib import Path

from augment import RandAug
import utils

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Much of the code is a modified version of the code available at
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

parser = argparse.ArgumentParser(description='arguments for training')

parser.add_argument('--tr_empty_folder', type=str, default="/path/to/empty/train/images",
                    help='path to training data with empty images')
parser.add_argument('--val_empty_folder', type=str, default="/path/to/empty/validation/images",
                    help='path to validation data with empty images')
parser.add_argument('--tr_ok_folder', type=str, default="/path/to/non-empty/train/images",
                    help='path to training data with ok images')
parser.add_argument('--val_ok_folder', type=str, default="/path/to/non-empty/validation/images",
                    help='path to validation data with ok images')
parser.add_argument('--results_folder', type=str, default="./results/",
                    help='Folder for saving training results.')
parser.add_argument('--save_model_path', type=str, default="./models/",
                    help='Path for saving model file.')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Batch size used for model training. ')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Base learning rate.')
parser.add_argument('--device', type=str, default='cpu',
                    help='Defines whether the model is trained using cpu or gpu.')
parser.add_argument('--num_classes', type=int, default=2,
                    help='Number of classes used in classification.')
parser.add_argument('--num_epochs', type=int, default=15,
                    help='Number of training epochs.')
parser.add_argument('--random_seed', type=int, default=8765,
                    help='Number used for initializing random number generation.')
parser.add_argument('--early_stop_threshold', type=int, default=3,
                    help='Threshold value of epochs after which training stops if validation accuracy does not improve.')
parser.add_argument('--save_model_format', type=str, default='onnx',
                    help='Defines the format for saving the model.')
parser.add_argument('--augment_choice', type=str, default=None,
                    help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
parser.add_argument('--model_name', type=str, default='test_model',
                    help='Name used for the saved model file.')
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
                    help='Current date.')
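# Example invocation (script name and paths are illustrative):
#   python train.py --tr_empty_folder ./data/train/empty --tr_ok_folder ./data/train/ok \
#       --val_empty_folder ./data/val/empty --val_ok_folder ./data/val/ok \
#       --device cuda --num_epochs 15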

args = parser.parse_args()

# PIL settings to avoid errors caused by truncated and large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# List for saving the names of damaged images
damaged_images = []

def get_datapaths():
    """Function for loading train and validation data."""
    tr_empty_files = list(Path(args.tr_empty_folder).glob('*'))
    tr_ok_files = list(Path(args.tr_ok_folder).glob('*'))
    val_empty_files = list(Path(args.val_empty_folder).glob('*'))
    val_ok_files = list(Path(args.val_ok_folder).glob('*'))
    # Create labels for train and validation data
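    # Label convention: 0 = empty image, 1 = ok (non-empty) image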
    tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
    val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))
    # Combine faulty and non-faulty images
    tr_files = tr_empty_files + tr_ok_files
    val_files = val_empty_files + val_ok_files

    print('\nTraining data with empty cells: ', len(tr_empty_files))
    print('Training data without empty cells: ', len(tr_ok_files))

    print('Validation data with empty cells: ', len(val_empty_files))
    print('Validation data without empty cells: ', len(val_ok_files))

    data_dict = {'tr_data': tr_files, 'tr_labels': tr_labels, 
            'val_data': val_files, 'val_labels': val_labels}

    return data_dict

class ImageDataset(Dataset):
    """PyTorch Dataset class is used for generating training and validation datasets."""
    def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.img_labels[idx]
        except Exception:
            # Image is considered damaged if reading it fails
            damaged_images.append(img_path)
            return None
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
            
        return image, label

def initialize_model():
    """Function for initializing pretrained neural network model (DenseNet121)."""
    model_ft = models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, args.num_classes)
    input_size = 224

    return model_ft, input_size

def collate_fn(batch):
    """Helper function for creating data batches."""
    batch = list(filter(lambda x: x is not None, batch))
 
    return torch.utils.data.dataloader.default_collate(batch)

def initialize_dataloaders(data_dict, input_size):
    """Function for initializing datasets and dataloaders."""
    # Train and validation datasets 
    train_dataset = ImageDataset(img_paths=data_dict['tr_data'], img_labels=data_dict['tr_labels'],  transform=RandAug(input_size, args.augment_choice))
    validation_dataset = ImageDataset(img_paths=data_dict['val_data'], img_labels=data_dict['val_labels'], transform=RandAug(input_size, 'identity'))
    # Train and validation dataloaders
    train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=True, num_workers=4)
    validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=args.batch_size, shuffle=False, num_workers=4)
    
    return {'train': train_dataloader, 'val': validation_dataloader}

def get_criterion(data_dict):
    """Function for generating class weights and initializing the loss function."""
    y = np.asarray(data_dict['tr_labels'])
    # Class weights are used for compensating the unbalance 
    # in the number of training data from the two classes
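    # sklearn's 'balanced' heuristic sets weight_c = n_samples / (n_classes * n_c);
    # e.g. with hypothetical counts of 1000 empty and 3000 ok images the weights
    # would be [2.0, 0.67]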
    class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(args.device)
    print('\nClass weights: ', class_weights.tolist())
    # Cross-entropy loss function
    criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

    return criterion

def get_optimizer(model):
    """Function for initializing the optimizer."""
    # Model parameters are split into two groups: parameters of the classifier
    # layer and other model parameters
    params_1 = [param for name, param in model.named_parameters()
                if name not in ["classifier.weight", "classifier.bias"]]
    params_2 = model.classifier.parameters()
    # 10 x larger learning rate is used when training the parameters 
    # of the classification layers
    params_to_update = [
            {'params': params_1, 'lr': args.lr},
            {'params': params_2, 'lr': args.lr * 10}
            ]
    # Adam optimizer
    optimizer = torch.optim.Adam(params_to_update, args.lr)
    # Scheduler reduces the learning rate when validation F1 does not improve for an epoch
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

    return optimizer, scheduler

def train_model(model, dataloaders, criterion, optimizer, scheduler=None):
    """Function for model training and validation."""
    since = time.time()
    # Lists for saving train and validation metrics for each epoch
    tr_loss_history = []
    tr_acc_history = []
    tr_f1_history = []
    val_loss_history = []
    val_acc_history = []
    val_f1_history = []
    # Lists for saving learning rates for the 2 parameter groups
    lr1_history = []
    lr2_history = []
    
    # Best F1 value and best epoch are saved in variables
    best_f1 = 0
    best_epoch = 0
    early_stop = False
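    # Training is stopped early if validation F1 does not improve for more than
    # args.early_stop_threshold consecutive epochs (see the validation phase below)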

    # Train / validation loop
    for epoch in tqdm(range(args.num_epochs)):
        # Save learning rates for the epoch
        lr1_history.append(optimizer.param_groups[0]["lr"])
        lr2_history.append(optimizer.param_groups[1]["lr"])
        
        print('Epoch {}/{}'.format(epoch+1, args.num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_f1 = 0.0

            # Iterate over the data in batches
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(args.device)
                labels = labels.long().to(args.device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Track history only in training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # Model predictions of the image labels for the batch
                    _, preds = torch.max(outputs, 1)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Weighted F1 score for the batch
                precision_recall_fscore = precision_recall_fscore_support(
                    labels.data.detach().cpu().numpy(),
                    preds.detach().cpu().numpy(),
                    average='weighted', zero_division=0)
                f1_score = precision_recall_fscore[2]

                # Update running statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data).cpu()
                running_f1 += f1_score
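            # Note: the epoch F1 below is the mean of per-batch F1 scores, which
            # approximates the true epoch-level F1 (all batches weighted equally)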

            # Calculate loss, accuracy and F1 score for the epoch
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            epoch_f1 = running_f1 / len(dataloaders[phase])

            print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(epoch+1, phase, epoch_loss, epoch_acc, epoch_f1))
            
            # Validation step
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                val_f1_history.append(epoch_f1)
                if epoch_f1 > best_f1:
                    print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(epoch_f1, best_f1))
                    # Model with best F1 score is saved
                    utils.save_model(model, 224, args.save_model_format, args.save_model_path, args.model_name, args.date)
                    model = model.to(args.device)
                    best_f1 = epoch_f1
                    best_epoch = epoch
                elif epoch - best_epoch > args.early_stop_threshold:
                    # Terminate training if validation F1 has not improved
                    # within the early stopping threshold
                    print("Early stopped training at epoch %d" % (epoch + 1))
                    # Set early stopping condition
                    early_stop = True
                    break
            elif phase == 'train':
                tr_acc_history.append(epoch_acc)
                tr_loss_history.append(epoch_loss)
                tr_f1_history.append(epoch_f1)

        # Break outer loop if early stopping condition is activated
        if early_stop:
            break
        # Take scheduler step
        if scheduler:
            scheduler.step(val_f1_history[-1])

    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation F1 score: {:.4f}'.format(best_f1))
    # The best model (by validation F1) was saved to disk during training;
    # the training history is returned for plotting
    hist_dict = {'tr_acc': tr_acc_history, 
                 'val_acc': val_acc_history, 
                 'val_loss': val_loss_history,
                 'val_f1': val_f1_history,
                 'tr_loss': tr_loss_history,
                 'tr_f1': tr_f1_history,
                 'lr1': lr1_history,
                 'lr2': lr2_history}

    return hist_dict

def main():
    # Set random seed(s)
    utils.set_seed(args.random_seed)
    # Load image paths and labels
    data_dict = get_datapaths()
    # Initialize the model 
    model, input_size = initialize_model()
    # Print the model architecture
    #print(model)
    # Send the model to GPU (if available)
    model = model.to(args.device)
    print("\nInitializing Datasets and Dataloaders...")
    dataloaders_dict = initialize_dataloaders(data_dict, input_size)
    criterion = get_criterion(data_dict)
    optimizer, scheduler = get_optimizer(model)
    # Train and evaluate model
    hist_dict = train_model(model, dataloaders_dict, criterion, optimizer, scheduler)
    print('Damaged images: ', damaged_images)
    utils.plot_metrics(hist_dict, args.results_folder, args.date)

if __name__ == '__main__':
    main()