|
from __future__ import print_function |
|
from __future__ import division |
|
import torch |
|
import torchvision |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import models |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
from sklearn.utils import class_weight |
|
from sklearn.metrics import precision_recall_fscore_support |
|
import numpy as np |
|
import time |
|
import argparse |
|
from tqdm import tqdm |
|
from PIL import Image, ImageFile |
|
from pathlib import Path |
|
|
|
from augment import RandAug |
|
import utils |
|
|
|
# Log library versions at start-up so that training runs can be matched to the
# exact PyTorch / torchvision releases they were produced with.
for _label, _version in (
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
):
    print(_label, _version)
|
|
|
|
|
|
|
|
|
# Command-line configuration for the training run. The first positional argument
# of ArgumentParser is ``prog`` (the program name shown in usage), so the
# human-readable text is passed explicitly as ``description``.
parser = argparse.ArgumentParser(description='arguments for training')

# Data locations: "empty" folders hold class-0 images, "ok" folders class-1 images.
parser.add_argument('--tr_empty_folder', type=str, default="/path/to/empty/train/images",
                    help='path to training data with empty images')
parser.add_argument('--val_empty_folder', type=str, default="/path/to/empty/validation/images",
                    help='path to validation data with empty images')
parser.add_argument('--tr_ok_folder', type=str, default="/path/to/non-empty/train/images",
                    help='path to training data with ok images')
parser.add_argument('--val_ok_folder', type=str, default="/path/to/non-empty/validation/images",
                    help='path to validation data with ok images')

# Output locations.
parser.add_argument('--results_folder', type=str, default="./results/",
                    help='Folder for saving training results.')
parser.add_argument('--save_model_path', type=str, default="./models/",
                    help='Path for saving model file.')

# Optimisation / training hyper-parameters.
parser.add_argument('--batch_size', type=int, default=32,
                    help='Batch size used for model training. ')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Base learning rate.')
parser.add_argument('--device', type=str, default='cpu',
                    help="Torch device used for training, e.g. 'cpu' or 'cuda'.")
parser.add_argument('--num_classes', type=int, default=2,
                    help='Number of classes used in classification.')
parser.add_argument('--num_epochs', type=int, default=15,
                    help='Number of training epochs.')
parser.add_argument('--random_seed', type=int, default=8765,
                    help='Number used for initializing random number generation.')
parser.add_argument('--early_stop_threshold', type=int, default=3,
                    help='Training stops once more than this many epochs have passed '
                         'without the validation F1 score improving.')

# Model saving / augmentation / run identification.
parser.add_argument('--save_model_format', type=str, default='onnx',
                    help='Defines the format for saving the model.')
parser.add_argument('--augment_choice', type=str, default=None,
                    help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
parser.add_argument('--model_name', type=str, default='test_model',
                    help='Model name passed to utils.save_model when saving the model.')
# Evaluated once at import time: the date the script was started.
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
                    help='Current date.')

args = parser.parse_args()
|
|
|
|
|
# Pillow configuration for a dataset that may contain very large or partially
# written image files:
#  * MAX_IMAGE_PIXELS = None disables Pillow's decompression-bomb size check,
#    so only use this on trusted image folders.
#  * LOAD_TRUNCATED_IMAGES = True lets Pillow decode files whose data ends early
#    instead of raising.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Paths of images that ImageDataset failed to open. Note: when the DataLoader
# uses worker processes, appends made inside workers are not visible in the
# main process, so this list may stay empty there.
damaged_images = list()
|
|
|
def get_datapaths():
    """Collect train/validation image paths and their binary labels.

    Reads the four folders given on the command line (``args.tr_empty_folder``,
    ``args.tr_ok_folder``, ``args.val_empty_folder``, ``args.val_ok_folder``).
    Files from an *empty* folder are labelled 0 and files from an *ok* folder
    are labelled 1.

    Only regular files directly inside each folder are used; sub-directories
    are skipped because they could never be opened as images. Paths are sorted
    because ``Path.glob`` order depends on the filesystem, and a fixed order is
    needed for seeded shuffling to be reproducible.

    Returns:
        dict with keys 'tr_data', 'tr_labels', 'val_data', 'val_labels'.
        The ``*_data`` values are lists of ``Path`` objects (empty-folder files
        first). The ``*_labels`` values are float numpy arrays aligned
        element-wise with them.
    """

    def _list_files(folder):
        # Regular files directly inside ``folder``, in a deterministic order.
        return sorted(p for p in Path(folder).glob('*') if p.is_file())

    tr_empty_files = _list_files(args.tr_empty_folder)
    tr_ok_files = _list_files(args.tr_ok_folder)
    val_empty_files = _list_files(args.val_empty_folder)
    val_ok_files = _list_files(args.val_ok_folder)

    # Label 0 for "empty" images, 1 for "ok" images, in the same order as the file lists.
    tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
    val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))

    print('\nTraining data with empty cells: ', len(tr_empty_files))
    print('Training data without empty cells: ', len(tr_ok_files))

    print('Validation data with empty cells: ', len(val_empty_files))
    print('Validation data without empty cells: ', len(val_ok_files))

    return {'tr_data': tr_empty_files + tr_ok_files, 'tr_labels': tr_labels,
            'val_data': val_empty_files + val_ok_files, 'val_labels': val_labels}
|
|
|
class ImageDataset(Dataset):
    """PyTorch Dataset that loads images from disk and pairs them with labels.

    Samples whose image cannot be opened or decoded are reported and returned
    as ``None``, so a collate function can drop them from the batch.
    """

    def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
        """
        Args:
            img_paths: Sequence of image file paths, indexed in parallel with img_labels.
            img_labels: Sequence of labels; its length defines the dataset length.
            transform: Optional callable applied to the RGB PIL image.
            target_transform: Optional callable applied to the label.
        """
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``None`` if the image cannot be loaded."""
        img_path = self.img_paths[idx]
        label = self.img_labels[idx]
        try:
            # The context manager closes the underlying file once decoding is done;
            # convert() returns an independent, fully loaded RGB copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as err:  # Pillow raises many different types for corrupt files.
            damaged_images.append(img_path)
            # Report immediately as well: with DataLoader worker processes this
            # runs in a child process, so the list append above is not seen by
            # the main process.
            print('Could not load image {}: {}'.format(img_path, err))
            return None

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
|
|
|
def initialize_model():
    """Load an ImageNet-pretrained DenseNet-121 and replace its classification head.

    Returns:
        tuple ``(model, input_size)``: the DenseNet-121 whose ``classifier`` is a
        fresh ``nn.Linear`` producing ``args.num_classes`` logits, and the input
        image size (224) used by this script.
    """
    pretrained = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
    network = models.densenet121(weights=pretrained)

    # Keep the backbone's feature width, but size the new head to our classes.
    head_in = network.classifier.in_features
    network.classifier = nn.Linear(head_in, args.num_classes)

    return network, 224
|
|
|
def collate_fn(batch):
    """Collate a batch after dropping samples that failed to load.

    ``ImageDataset`` yields ``None`` for unreadable images; those entries are
    removed before delegating to PyTorch's ``default_collate``.

    Returns:
        The collated batch, or ``None`` if every sample in the batch was
        dropped. ``default_collate`` raises an IndexError on an empty list, so
        the consumer of the DataLoader must skip ``None`` batches.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None
    return torch.utils.data.dataloader.default_collate(batch)
|
|
|
def initialize_dataloaders(data_dict, input_size):
    """Build the training and validation DataLoaders.

    Args:
        data_dict: dict with keys 'tr_data', 'tr_labels', 'val_data',
            'val_labels' (as returned by ``get_datapaths``).
        input_size: image size passed to ``RandAug``.

    Returns:
        dict ``{'train': DataLoader, 'val': DataLoader}``. Both loaders use
        ``collate_fn``, which drops samples that failed to load.
    """
    train_dataset = ImageDataset(img_paths=data_dict['tr_data'], img_labels=data_dict['tr_labels'],
                                 transform=RandAug(input_size, args.augment_choice))
    # 'identity' presumably means "no augmentation" for validation -- confirm in augment.RandAug.
    validation_dataset = ImageDataset(img_paths=data_dict['val_data'], img_labels=data_dict['val_labels'],
                                      transform=RandAug(input_size, 'identity'))

    train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4)
    # Validation order does not need to be random. A fixed order keeps the
    # validation batches, and any per-batch statistics computed from them,
    # identical across epochs, so metric changes reflect the model.
    validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=args.batch_size,
                                       shuffle=False, num_workers=4)

    return {'train': train_dataloader, 'val': validation_dataloader}
|
|
|
def get_criterion(data_dict):
    """Build a class-weighted cross-entropy loss from the training labels.

    Weights come from scikit-learn's 'balanced' heuristic over
    ``data_dict['tr_labels']``, so classes that are rarer in the training set
    receive larger weights. The weight tensor is placed on ``args.device`` and
    printed for reference.
    """
    labels = np.asarray(data_dict['tr_labels'])
    balanced = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.unique(labels),
        y=labels,
    )
    weight_tensor = torch.tensor(balanced, dtype=torch.float).to(args.device)
    print('\nClass weights: ', weight_tensor.tolist())

    return nn.CrossEntropyLoss(weight=weight_tensor, reduction='mean')
|
|
|
def get_optimizer(model):
    """Create an Adam optimizer with a faster-learning classifier head, plus a scheduler.

    Two parameter groups are built:
      * group 0 - every parameter except ``classifier.weight`` / ``classifier.bias``,
        at ``args.lr``;
      * group 1 - the parameters of ``model.classifier``, at ``10 * args.lr``.

    The ReduceLROnPlateau scheduler runs in 'max' mode, so it expects to be
    stepped with a score to maximise. It multiplies the learning rates by 0.1
    after a single step without improvement (patience=0).

    Returns:
        tuple ``(optimizer, scheduler)``.
    """
    head_names = {"classifier.weight", "classifier.bias"}
    backbone_params = [p for n, p in model.named_parameters() if n not in head_names]

    param_groups = [
        {'params': backbone_params, 'lr': args.lr},
        {'params': model.classifier.parameters(), 'lr': args.lr * 10},
    ]
    optimizer = torch.optim.Adam(param_groups, args.lr)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

    return optimizer, scheduler
|
|
|
def _run_phase(model, dataloader, criterion, optimizer, phase):
    """Run one pass over ``dataloader``; return ``(loss, accuracy, weighted F1)``.

    Gradients are computed and the optimizer stepped only when ``phase`` is
    'train'. Metrics are computed over the samples actually processed, so
    batches or samples dropped by the collate function (unreadable images) do
    not distort them. The F1 is computed over the whole pass, not averaged per
    batch.
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    running_corrects = 0
    n_samples = 0
    all_labels = []
    all_preds = []

    for batch in dataloader:
        # collate_fn may yield None when every sample of a batch failed to load.
        if batch is None:
            continue
        inputs, labels = batch
        inputs = inputs.to(args.device)
        labels = labels.long().to(args.device)

        optimizer.zero_grad()

        # Track gradients only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += int((preds == labels).sum().item())
        n_samples += batch_size
        all_labels.append(labels.detach().cpu().numpy())
        all_preds.append(preds.detach().cpu().numpy())

    if n_samples == 0:
        # Nothing could be loaded in this phase; report zeros instead of dividing by zero.
        return 0.0, torch.tensor(0.0, dtype=torch.float64), 0.0

    epoch_loss = running_loss / n_samples
    # Accuracy is kept as a 0-dim double tensor, the type stored in the returned history.
    epoch_acc = torch.tensor(running_corrects / n_samples, dtype=torch.float64)
    epoch_f1 = precision_recall_fscore_support(np.concatenate(all_labels), np.concatenate(all_preds),
                                               average='weighted', zero_division=0)[2]
    return epoch_loss, epoch_acc, epoch_f1


def train_model(model, dataloaders, criterion, optimizer, scheduler=None, input_size=224):
    """Train and validate ``model``, checkpointing on the best validation F1.

    Each epoch runs a 'train' phase followed by a 'val' phase. Whenever the
    validation weighted F1 improves on the best so far, the model is saved with
    ``utils.save_model``. Training stops early once more than
    ``args.early_stop_threshold`` epochs have passed since the best epoch. If
    given, ``scheduler`` is stepped with the validation F1 after every epoch
    that does not stop early.

    Args:
        model: the network to train.
        dataloaders: dict with 'train' and 'val' DataLoaders.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer with at least two parameter groups (their learning
            rates are recorded as 'lr1' and 'lr2').
        scheduler: optional scheduler with a ``step(metric)`` method.
        input_size: image size passed to ``utils.save_model`` (default 224, the
            size this script's model uses).

    Returns:
        dict of per-epoch histories with keys 'tr_acc', 'val_acc', 'val_loss',
        'val_f1', 'tr_loss', 'tr_f1', 'lr1', 'lr2'.
    """
    since = time.time()

    history = {key: [] for key in ('tr_acc', 'val_acc', 'val_loss', 'val_f1',
                                   'tr_loss', 'tr_f1', 'lr1', 'lr2')}
    best_f1 = 0
    best_epoch = 0

    for epoch in tqdm(range(args.num_epochs)):
        # Learning rates of the first two parameter groups at the start of the epoch.
        history['lr1'].append(optimizer.param_groups[0]["lr"])
        history['lr2'].append(optimizer.param_groups[1]["lr"])

        print('Epoch {}/{}'.format(epoch+1, args.num_epochs))
        print('-' * 10)

        for phase, prefix in (('train', 'tr'), ('val', 'val')):
            epoch_loss, epoch_acc, epoch_f1 = _run_phase(model, dataloaders[phase], criterion, optimizer, phase)
            print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(
                epoch+1, phase, epoch_loss, epoch_acc, epoch_f1))
            history[prefix + '_acc'].append(epoch_acc)
            history[prefix + '_loss'].append(epoch_loss)
            history[prefix + '_f1'].append(epoch_f1)

        # Checkpointing and early stopping are driven by this epoch's validation F1.
        val_f1 = history['val_f1'][-1]
        if val_f1 > best_f1:
            print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(val_f1, best_f1))
            utils.save_model(model, input_size, args.save_model_format, args.save_model_path,
                             args.model_name, args.date)
            # Move the model back to the training device in case saving relocated
            # it (presumably utils.save_model does for export -- confirm).
            model = model.to(args.device)
            best_f1 = val_f1
            best_epoch = epoch
        elif epoch - best_epoch > args.early_stop_threshold:
            print("Early stopped training at epoch %d" % (epoch + 1))
            break

        if scheduler:
            scheduler.step(val_f1)

    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation F1 score: {:.4f}'.format(best_f1))

    return history
|
|
|
def main():
    """Script entry point: prepare data, model, loss and optimizer, then train.

    Seeds via ``utils.set_seed``, trains with ``train_model``, prints the
    collected list of damaged image paths and plots the training history with
    ``utils.plot_metrics``.
    """
    utils.set_seed(args.random_seed)

    paths = get_datapaths()

    net, size = initialize_model()
    net = net.to(args.device)

    print("\nInitializing Datasets and Dataloaders...")
    loaders = initialize_dataloaders(paths, size)
    loss_fn = get_criterion(paths)
    optimizer, scheduler = get_optimizer(net)

    history = train_model(net, loaders, loss_fn, optimizer, scheduler)

    # NOTE: the DataLoaders are created with num_workers=4, so images are
    # loaded in worker processes; their appends to the module-level
    # ``damaged_images`` list are not visible here, and it may print empty.
    print('Damaged images: ', damaged_images)
    utils.plot_metrics(history, args.results_folder, args.date)


if __name__ == "__main__":
    main()