Spaces:
Configuration error
Configuration error
# train.py | |
#!/usr/bin/env python3 | |
""" train network using pytorch | |
Junde Wu | |
""" | |
import argparse | |
import os | |
import sys | |
import time | |
from collections import OrderedDict | |
from datetime import datetime | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from skimage import io | |
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score | |
from tensorboardX import SummaryWriter | |
#from dataset import * | |
from torch.autograd import Variable | |
from torch.utils.data import DataLoader, random_split | |
from torch.utils.data.sampler import SubsetRandomSampler | |
from tqdm import tqdm | |
import cfg | |
import function | |
from conf import settings | |
#from models.discriminatorlayer import discriminator | |
from dataset import * | |
from utils import * | |
def _validate_and_log(args, loader, epoch, net, writer, logger):
    """Run one validation pass, log its metrics and return ``(tol, dice)``.

    For every dataset except 'REFUGE' the Dice value returned by
    ``function.validation_sam`` is passed through unchanged. For 'REFUGE'
    the validation yields separate cup and disc scores, and the returned
    Dice is their arithmetic mean so that callers get a single number to
    compare against the best value so far.
    NOTE(review): averaging cup/disc Dice is a choice made here to give
    REFUGE a model-selection metric — confirm it is the desired criterion.
    """
    if args.dataset != 'REFUGE':
        tol, (eiou, edice) = function.validation_sam(args, loader, epoch, net, writer)
        logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
    else:
        tol, (eiou_cup, eiou_disc, edice_cup, edice_disc) = function.validation_sam(args, loader, epoch, net, writer)
        logger.info(f'Total score: {tol}, IOU_CUP: {eiou_cup}, IOU_DISC: {eiou_disc}, DICE_CUP: {edice_cup}, DICE_DISC: {edice_disc} || @ epoch {epoch}.')
        edice = (edice_cup + edice_disc) / 2
    return tol, edice


def main():
    """Train the network selected by the command-line arguments.

    Builds the network, an Adam optimizer, the logger, the data loaders and
    a TensorBoard writer, then runs ``settings.EPOCH`` epochs of
    ``function.train_sam``. Validation runs before training during epochs
    1-4, and after training every ``args.val_freq`` epochs as well as on the
    final epoch. Whenever the post-training validation Dice exceeds the best
    value seen so far, a checkpoint is written as
    ``best_dice_checkpoint.pth`` under ``args.path_helper['ckpt_path']``.

    Raises:
        FileNotFoundError: if ``args.weights`` is set but the file is missing.
    """
    args = cfg.parse_args()

    GPUdevice = torch.device('cuda', args.gpu_device)

    net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)
    if args.pretrain:
        weights = torch.load(args.pretrain)
        net.load_state_dict(weights, strict=False)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    # NOTE(review): this scheduler is constructed but never stepped, so the
    # learning rate stays constant. Left as-is to avoid silently changing
    # training dynamics; call scheduler.step() once per epoch if decay is
    # actually wanted.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning rate decay

    # Optionally load model weights from a previous checkpoint.
    if args.weights != 0:
        print(f'=> resuming from {args.weights}')
        checkpoint_file = os.path.join(args.weights)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
        loc = 'cuda:{}'.format(args.gpu_device)
        checkpoint = torch.load(checkpoint_file, map_location=loc)
        start_epoch = checkpoint['epoch']

        net.load_state_dict(checkpoint['state_dict'], strict=False)
        # optimizer.load_state_dict(checkpoint['optimizer'], strict=False)

        # NOTE(review): only the weights are effectively resumed. The epoch
        # counter below still starts at 0, and the path_helper restored here
        # is replaced by a fresh log directory right after this branch.
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    nice_train_loader, nice_test_loader = get_dataloader(args)

    # Checkpoint directory and TensorBoard writer.
    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)

    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(settings.LOG_DIR, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(
        settings.LOG_DIR, args.net, settings.TIME_NOW))

    # Create the checkpoint folder used to save models.
    os.makedirs(checkpoint_path, exist_ok=True)

    # Begin training.
    best_dice = 0.0

    try:
        for epoch in range(settings.EPOCH):
            # Extra validation pass before training during epochs 1-4.
            if epoch and epoch < 5:
                _validate_and_log(args, nice_test_loader, epoch, net, writer, logger)

            net.train()
            time_start = time.time()
            loss = function.train_sam(args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis)
            logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
            time_end = time.time()
            print('time_for_training ', time_end - time_start)

            net.eval()
            # Validate every val_freq epochs (never at epoch 0) and always
            # on the last epoch.
            if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
                tol, edice = _validate_and_log(args, nice_test_loader, epoch, net, writer, logger)

                if edice > best_dice:
                    # Record the new best so later epochs must beat it.
                    best_dice = edice

                    if args.distributed != 'none':
                        sd = net.module.state_dict()
                    else:
                        sd = net.state_dict()

                    save_checkpoint({
                        'epoch': epoch + 1,
                        'model': args.net,
                        'state_dict': sd,
                        'optimizer': optimizer.state_dict(),
                        # Key name kept for checkpoint compatibility; the
                        # stored value is the best Dice score.
                        'best_tol': best_dice,
                        'path_helper': args.path_helper,
                    }, True, args.path_helper['ckpt_path'], filename="best_dice_checkpoint.pth")
    finally:
        # Flush and close TensorBoard even if training is interrupted.
        writer.close()
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()