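# Hugging Face file-page header captured with this source (not part of the program):
# author: MikkoLipsanen; last commit: "Update train.py" (9377114, verified)
# file size: 14.7 kB; host malware scan: no virus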
from __future__ import print_function
from __future__ import division
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.utils import class_weight
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import time
import argparse
from tqdm import tqdm
from PIL import Image, ImageFile
from pathlib import Path
from augment import RandAug
import utils
# Report library versions up front so the training log records the environment.
for _lib_label, _lib_version in (("PyTorch Version: ", torch.__version__),
                                 ("Torchvision Version: ", torchvision.__version__)):
    print(_lib_label, _lib_version)
del _lib_label, _lib_version
# Much of the code is a modified version of the code available at
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
# Command-line configuration. ``args`` is parsed once when the module is loaded
# and is read as a module-level global by the functions below.
parser = argparse.ArgumentParser('arguments for training')
parser.add_argument('--tr_empty_folder', type=str, default="/path/to/empty/train/images",
                    help='path to training data with empty images')
parser.add_argument('--val_empty_folder', type=str, default="/path/to/empty/validation/images",
                    help='path to validation data with empty images')
parser.add_argument('--tr_ok_folder', type=str, default="/path/to/non-empty/train/images",
                    help='path to training data with ok images')
parser.add_argument('--val_ok_folder', type=str, default="/path/to/non-empty/validation/images",
                    help='path to validation data with ok images')
parser.add_argument('--results_folder', type=str, default="./results/",
                    help='Folder for saving training results.')
parser.add_argument('--save_model_path', type=str, default="./models/",
                    help='Path for saving model file.')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Batch size used for model training. ')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Base learning rate (the classifier layer uses 10x this value).')
parser.add_argument('--device', type=str, default='cpu',
                    help="Device the model and data are moved to, e.g. 'cpu' or 'cuda'.")
parser.add_argument('--num_classes', type=int, default=2,
                    help='Number of classes used in classification.')
parser.add_argument('--num_epochs', type=int, default=15,
                    help='Number of training epochs.')
parser.add_argument('--random_seed', type=int, default=8765,
                    help='Number used for initializing random number generation.')
parser.add_argument('--early_stop_threshold', type=int, default=3,
                    help='Training stops once more than this many epochs have passed '
                         'without an improvement in the validation F1 score.')
parser.add_argument('--save_model_format', type=str, default='onnx',
                    help='Defines the format for saving the model.')
parser.add_argument('--augment_choice', type=str, default=None,
                    help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
parser.add_argument('--model_name', type=str, default='test_model',
                    help='Name under which the trained model is saved.')
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
                    help='Date string (DDMMYYYY, defaults to today) used when saving the model and metric plots.')
args = parser.parse_args()
# Pillow configuration for reading the training images. Truncated files are
# loaded as far as possible instead of raising, and the pixel-count limit
# (Pillow's decompression-bomb guard) is disabled so very large images open.
# Only appropriate for trusted, locally stored data.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Paths of images that ImageDataset could not read (filled in __getitem__).
damaged_images = list()
def get_datapaths():
    """Collect train/validation image paths and their binary labels.

    Reads the four folders given on the command line (``args``). Files from the
    "empty" folders are labelled 0 and files from the "ok" folders are labelled 1.
    Paths are sorted so the sample order (and therefore the seeded DataLoader
    shuffling) does not depend on the filesystem's listing order. Sub-directories
    are skipped because they cannot be read as images.

    Returns:
        dict with keys 'tr_data', 'tr_labels', 'val_data', 'val_labels'. The
        ``*_data`` values are lists of ``pathlib.Path``. The ``*_labels`` values
        are float numpy arrays aligned element-wise with them.
    """
    def _list_files(folder):
        # Regular files only, in a deterministic (sorted) order.
        return sorted(path for path in Path(folder).glob('*') if path.is_file())

    tr_empty_files = _list_files(args.tr_empty_folder)
    tr_ok_files = _list_files(args.tr_ok_folder)
    val_empty_files = _list_files(args.val_empty_folder)
    val_ok_files = _list_files(args.val_ok_folder)
    # Create labels for train and validation data: 0 = empty, 1 = ok
    tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
    val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))
    # Combine empty and non-empty images in the same order as the labels
    tr_files = tr_empty_files + tr_ok_files
    val_files = val_empty_files + val_ok_files
    print('\nTraining data with empty cells: ', len(tr_empty_files))
    print('Training data without empty cells: ', len(tr_ok_files))
    print('Validation data with empty cells: ', len(val_empty_files))
    print('Validation data without empty cells: ', len(val_ok_files))
    data_dict = {'tr_data': tr_files, 'tr_labels': tr_labels,
                 'val_data': val_files, 'val_labels': val_labels}
    return data_dict
class ImageDataset(Dataset):
    """Map-style PyTorch dataset of (image, label) pairs read from file paths.

    A sample whose image cannot be opened or decoded is recorded in the
    module-level ``damaged_images`` list and returned as ``None``.
    ``collate_fn`` removes such samples from each batch.

    NOTE(review): with DataLoader workers (``num_workers > 0``), ``__getitem__``
    runs in worker processes. Appends to ``damaged_images`` there are likely not
    visible in the main process — confirm before relying on that list.
    """

    def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
        """
        Args:
            img_paths: sequence of image file paths.
            img_labels: sequence of labels aligned with ``img_paths``.
            transform: optional callable applied to the RGB PIL image.
            target_transform: optional callable applied to the label.
        """
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``None`` if the image is unreadable."""
        img_path = self.img_paths[idx]
        label = self.img_labels[idx]
        try:
            # The ``with`` block closes the file handle; convert() returns an
            # independent in-memory RGB copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception:
            # Any read/decode failure marks the image as damaged. Unlike a bare
            # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            damaged_images.append(img_path)
            return None
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
def initialize_model():
    """Build an ImageNet-pretrained DenseNet-121 with a new classification head.

    The final ``classifier`` layer is replaced by a linear layer with
    ``args.num_classes`` outputs. All other weights come from the
    IMAGENET1K_V1 checkpoint.

    Returns:
        tuple ``(model, input_size)``. ``input_size`` (224) is the image side
        length later passed to the augmentation pipeline.
    """
    pretrained = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
    network = models.densenet121(weights=pretrained)
    network.classifier = nn.Linear(network.classifier.in_features, args.num_classes)
    return network, 224
def collate_fn(batch):
    """Collate samples into a batch, dropping samples that failed to load.

    ``ImageDataset.__getitem__`` returns ``None`` for unreadable images. Those
    entries are removed before delegating to PyTorch's ``default_collate``.

    Args:
        batch: list of samples from the dataset, each ``(image, label)`` or ``None``.

    Returns:
        The collated batch. Returns ``None`` if no valid sample remains, because
        ``default_collate`` cannot collate an empty list. Code iterating the
        DataLoader should skip ``None`` batches.
    """
    valid_samples = [sample for sample in batch if sample is not None]
    if not valid_samples:
        return None
    return torch.utils.data.dataloader.default_collate(valid_samples)
def initialize_dataloaders(data_dict, input_size, num_workers=4):
    """Create the training and validation datasets and DataLoaders.

    Args:
        data_dict: dict from ``get_datapaths`` with 'tr_data', 'tr_labels',
            'val_data' and 'val_labels'.
        input_size: image side length passed to ``RandAug``.
        num_workers: number of DataLoader worker processes (default 4).

    Returns:
        dict ``{'train': DataLoader, 'val': DataLoader}``.
    """
    # Training images use the augmentation(s) selected on the command line.
    # Validation uses RandAug's 'identity' choice (presumably no augmentation —
    # confirm in augment.RandAug).
    train_dataset = ImageDataset(img_paths=data_dict['tr_data'], img_labels=data_dict['tr_labels'],
                                 transform=RandAug(input_size, args.augment_choice))
    validation_dataset = ImageDataset(img_paths=data_dict['val_data'], img_labels=data_dict['val_labels'],
                                      transform=RandAug(input_size, 'identity'))
    train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.batch_size,
                                  shuffle=True, num_workers=num_workers)
    # Evaluation needs no shuffling. A fixed order also keeps any metric that is
    # aggregated per batch independent of random batch composition.
    validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=args.batch_size,
                                       shuffle=False, num_workers=num_workers)
    return {'train': train_dataloader, 'val': validation_dataloader}
def get_criterion(data_dict):
    """Create a class-weighted cross-entropy loss from the training labels.

    scikit-learn's 'balanced' class weights compensate for the unequal number
    of training images in each class. The weights are moved to ``args.device``
    and printed.

    Args:
        data_dict: dict from ``get_datapaths``. Only 'tr_labels' is used.

    Returns:
        ``nn.CrossEntropyLoss`` with per-class weights and mean reduction.
    """
    train_labels = np.asarray(data_dict['tr_labels'])
    balanced = class_weight.compute_class_weight(
        class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
    weights = torch.tensor(balanced, dtype=torch.float).to(args.device)
    print('\nClass weights: ', weights.tolist())
    return nn.CrossEntropyLoss(weight=weights, reduction='mean')
def get_optimizer(model):
    """Create the Adam optimizer and a plateau-based learning-rate scheduler.

    Parameters are split into two groups:
      * group 0: every parameter outside ``model.classifier``, at ``args.lr``;
      * group 1: the parameters of ``model.classifier`` (the newly created head),
        at ``args.lr * 10``.
    The split is by parameter identity, so any ``classifier`` module (single
    layer or a sequence of layers) works without parameters appearing in both
    groups.

    The scheduler is ``ReduceLROnPlateau`` in 'max' mode with ``factor=0.1`` and
    ``patience=0``. The caller steps it with a score to maximise (the validation
    F1 score in ``train_model``). Both learning rates then drop 10x after any
    epoch in which that score does not improve.

    Args:
        model: ``nn.Module`` that has a ``classifier`` submodule.

    Returns:
        tuple ``(optimizer, scheduler)``.
    """
    classifier_params = list(model.classifier.parameters())
    classifier_ids = {id(param) for param in classifier_params}
    backbone_params = [param for param in model.parameters() if id(param) not in classifier_ids]
    params_to_update = [
        {'params': backbone_params, 'lr': args.lr},
        # 10x larger learning rate for the freshly initialised classifier layer
        {'params': classifier_params, 'lr': args.lr * 10},
    ]
    optimizer = torch.optim.Adam(params_to_update, lr=args.lr)
    # ``verbose`` is deprecated since PyTorch 2.2 and therefore not passed.
    # train_model records the per-epoch learning rates of both groups instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0)
    return optimizer, scheduler
def _run_phase(model, dataloader, criterion, optimizer, phase):
    """Run one pass over ``dataloader`` and return ``(loss, accuracy, weighted_f1)``.

    In the 'train' phase gradients are tracked and the optimizer is stepped
    after every batch. In any other phase the model is only evaluated.
    ``None`` batches (no sample of the batch could be loaded) are skipped.

    Loss and accuracy are averaged over the samples actually processed. The
    weighted F1 score is computed once over all predictions of the pass: a
    mean of per-batch F1 scores is not the epoch's F1 score, because batches
    differ in size and class mix. Returns zeros if no sample was processed.
    """
    is_train = phase == 'train'
    # Training mode for 'train', evaluation mode otherwise
    model.train(is_train)
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    label_chunks = []
    pred_chunks = []
    for batch in dataloader:
        if batch is None:
            continue
        inputs, labels = batch
        inputs = inputs.to(args.device)
        labels = labels.long().to(args.device)
        if is_train:
            optimizer.zero_grad()
        # Track history only in the training phase
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Predicted class = index of the largest output
            _, preds = torch.max(outputs, 1)
            if is_train:
                loss.backward()
                optimizer.step()
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += int((preds == labels).sum().item())
        total_samples += batch_size
        label_chunks.append(labels.detach().cpu().numpy())
        pred_chunks.append(preds.detach().cpu().numpy())
    if total_samples == 0:
        return 0.0, 0.0, 0.0
    epoch_f1 = precision_recall_fscore_support(
        np.concatenate(label_chunks), np.concatenate(pred_chunks),
        average='weighted', zero_division=0)[2]
    return total_loss / total_samples, total_correct / total_samples, epoch_f1


def train_model(model, dataloaders, criterion, optimizer, scheduler=None):
    """Train ``model`` and track the best epoch by validation F1 score.

    Each epoch runs a 'train' phase and then a 'val' phase. Whenever the
    validation F1 score improves, the model is saved with ``utils.save_model``.
    Training stops early once more than ``args.early_stop_threshold`` epochs
    have passed since the best epoch. If ``scheduler`` is given, it is stepped
    with the validation F1 score after every epoch.

    Note: only the metric history is returned. ``model`` keeps the weights of
    the last epoch run, not those of the best epoch (those exist only in the
    saved model file).

    Args:
        model: the network to train.
        dataloaders: dict with 'train' and 'val' DataLoaders.
        criterion: loss function.
        optimizer: optimizer with two parameter groups (see ``get_optimizer``).
        scheduler: optional scheduler whose ``step`` takes the validation F1.

    Returns:
        dict of per-epoch lists with keys 'tr_acc', 'val_acc', 'val_loss',
        'val_f1', 'tr_loss', 'tr_f1', 'lr1' and 'lr2'. 'lr1' and 'lr2' hold the
        learning rates of the two optimizer parameter groups.
    """
    since = time.time()
    hist_dict = {key: [] for key in ('tr_acc', 'val_acc', 'val_loss', 'val_f1',
                                     'tr_loss', 'tr_f1', 'lr1', 'lr2')}
    # Best validation F1 value and the (0-based) epoch in which it was reached
    best_f1 = 0
    best_epoch = 0
    for epoch in tqdm(range(args.num_epochs)):
        # Save learning rates for the epoch
        hist_dict['lr1'].append(optimizer.param_groups[0]["lr"])
        hist_dict['lr2'].append(optimizer.param_groups[1]["lr"])
        print('Epoch {}/{}'.format(epoch+1, args.num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase, prefix in (('train', 'tr'), ('val', 'val')):
            epoch_loss, epoch_acc, epoch_f1 = _run_phase(model, dataloaders[phase], criterion, optimizer, phase)
            print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(epoch+1, phase, epoch_loss, epoch_acc, epoch_f1))
            hist_dict[prefix + '_acc'].append(epoch_acc)
            hist_dict[prefix + '_loss'].append(epoch_loss)
            hist_dict[prefix + '_f1'].append(epoch_f1)
        val_f1 = hist_dict['val_f1'][-1]
        if val_f1 > best_f1:
            print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(val_f1, best_f1))
            # Model with best F1 score is saved
            utils.save_model(model, 224, args.save_model_format, args.save_model_path, args.model_name, args.date)
            # NOTE(review): presumably save_model may move the model off the
            # training device (e.g. for export); move it back — confirm in utils.
            model = model.to(args.device)
            best_f1 = val_f1
            best_epoch = epoch
        elif epoch - best_epoch > args.early_stop_threshold:
            # Validation F1 has not improved for too many epochs
            print("Early stopped training at epoch %d" % (epoch + 1))
            break
        # Take scheduler step
        if scheduler:
            scheduler.step(val_f1)
    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation F1 score: {:.4f}'.format(best_f1))
    return hist_dict
def main():
    """Script entry point: load data, build the model, train it, plot metrics."""
    # Seed the random number generators (see utils.set_seed)
    utils.set_seed(args.random_seed)
    # Image paths and labels for both splits
    data_dict = get_datapaths()
    # Pretrained network and the image size it expects
    model_ft, input_size = initialize_model()
    model_ft = model_ft.to(args.device)
    print("\nInitializing Datasets and Dataloaders...")
    loaders = initialize_dataloaders(data_dict, input_size)
    loss_fn = get_criterion(data_dict)
    optimizer, scheduler = get_optimizer(model_ft)
    # Train and evaluate, then report unreadable files and plot the history
    history = train_model(model_ft, loaders, loss_fn, optimizer, scheduler)
    print('Damaged images: ', damaged_images)
    utils.plot_metrics(history, args.results_folder, args.date)


if __name__ == '__main__':
    main()