# WSCL / main.py
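"""Training and evaluation entry point for WSCL.

Builds per-dataset train/validation loaders, an ensemble of single-modality
models, and one optimizer/scheduler per modality, then trains with a bundled
multi-view consistency loss, evaluating and checkpointing periodically.

Hypothetical invocation (flag names inferred from the opt attributes used
below; see opt.py for the actual parser):

    python main.py --train_datalist <file> --val_datalist <file>
"""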
import datetime
import math
import os
from functools import partial

import albumentations as A
import torch.optim as optim
from termcolor import cprint
from timm.scheduler import create_scheduler
from torch.utils.data import DataLoader

import utils.misc as misc
from datasets import crop_to_smallest_collate_fn, get_dataset
from engine import bundled_evaluate, train
from losses import get_bundled_loss, get_loss
from models import get_ensemble_model, get_single_modal_model
from opt import get_opt


def main(opt):
# get tensorboard writer
writer = misc.setup_env(opt)
# dataset
# training sets
train_loaders = {}
if not opt.eval:
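        # augmentations: optional resize + random crop (opt.resize_aug), plus
        # Gaussian blur / color jitter / JPEG compression, each applied with
        # p=0.5 unless disabled by its flag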
train_transform = A.Compose(
[
A.HorizontalFlip(0.5),
A.SmallestMaxSize(int(opt.input_size * 1.5))
if opt.resize_aug
else A.NoOp(),
A.RandomSizedCrop(
(opt.input_size, int(opt.input_size * 1.5)),
opt.input_size,
opt.input_size,
)
if opt.resize_aug
else A.NoOp(),
A.NoOp() if opt.no_gaussian_blur else A.GaussianBlur(p=0.5),
A.NoOp() if opt.no_color_jitter else A.ColorJitter(p=0.5),
A.NoOp() if opt.no_jpeg_compression else A.ImageCompression(p=0.5),
]
)
train_sets = get_dataset(opt.train_datalist, "train", train_transform, opt)
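        # one loader per training set; the collate fn crops each batch to its
        # smallest sample (capped at opt.input_size) so images can be stacked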
for k, dataset in train_sets.items():
train_loaders[k] = DataLoader(
dataset,
batch_size=opt.batch_size,
shuffle=True,
pin_memory=True,
num_workers=0 if opt.debug else opt.num_workers,
collate_fn=partial(
crop_to_smallest_collate_fn,
max_size=opt.input_size,
uncorrect_label=opt.uncorrect_label,
),
)
# validation sets
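    # "rescale" shrinks each image's short side to opt.tile_size; otherwise
    # images stay full-size (presumably tiled downstream, given opt.tile_size)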
if opt.large_image_strategy == "rescale":
val_transform = A.Compose([A.SmallestMaxSize(opt.tile_size)])
else:
val_transform = None
val_sets = get_dataset(opt.val_datalist, opt.val_set, val_transform, opt)
val_loaders = {}
for k, dataset in val_sets.items():
val_loaders[k] = DataLoader(
dataset,
batch_size=1,
shuffle=opt.val_shuffle,
pin_memory=True,
num_workers=0 if opt.debug else opt.num_workers,
)
# multi-view models and optimizers
optimizer_dict = {}
scheduler_dict = {}
model = get_ensemble_model(opt).to(opt.device)
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(
f"Number of total params: {n_param}, num params per model: {int(n_param / len(opt.modality))}"
)
# optimizer and scheduler
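    # one optimizer and timm scheduler per modality, each driving only that
    # modality's sub-model parameters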
for modality in opt.modality:
if opt.optimizer.lower() == "adamw":
optimizer = optim.AdamW(
model.sub_models[modality].parameters(),
opt.lr,
weight_decay=opt.weight_decay,
)
elif opt.optimizer.lower() == "sgd":
optimizer = optim.SGD(
model.sub_models[modality].parameters(),
opt.lr,
opt.momentum,
weight_decay=opt.weight_decay,
)
else:
raise RuntimeError(f"Unsupported optimizer {opt.optimizer}.")
        scheduler, num_epochs = create_scheduler(opt, optimizer)
optimizer_dict[modality] = optimizer
scheduler_dict[modality] = scheduler
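        # create_scheduler also returns the effective epoch count (e.g.,
        # including any cooldown epochs), which overrides opt.epochs below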
        opt.epochs = num_epochs
# loss functions
# loss function including the multi-view consistency loss, for training
bundled_criterion = get_bundled_loss(opt).to(opt.device)
# loss function excluding the multi-view consistency loss, for evaluation
single_criterion = get_loss(opt).to(opt.device)
if opt.resume:
misc.resume_from(model, opt.resume)
if opt.eval:
bundled_evaluate(
model, val_loaders, single_criterion, 0, writer, suffix="val", opt=opt
)
return
    cprint(f"The training will last for {opt.epochs} epochs.", "blue")
best_ensemble_image_f1 = -math.inf
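    # main loop: train on every training set each epoch, step the schedulers,
    # and periodically evaluate and checkpoint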
for epoch in range(opt.epochs):
for title, dataloader in train_loaders.items():
train(
model,
dataloader,
title,
optimizer_dict,
bundled_criterion,
epoch,
writer,
suffix="train",
opt=opt,
)
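        # log the current learning rate once per epoch (via timm's private
        # _get_lr) and step every per-modality scheduler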
for sched_idx, scheduler in enumerate(scheduler_dict.values()):
if sched_idx == 0 and writer is not None:
writer.add_scalar("lr", scheduler._get_lr(epoch)[0], epoch)
scheduler.step(epoch)
        if (epoch + 1) % opt.eval_freq == 0 or epoch == opt.epochs - 1:
result = bundled_evaluate(
model,
val_loaders,
single_criterion,
epoch,
writer,
suffix="val",
opt=opt,
)
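            # save an epoch-numbered checkpoint, and separately keep the best
            # checkpoint by ensemble image-level F1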
misc.save_model(
os.path.join(
opt.save_root_path, opt.dir_name, "checkpoint", f"{epoch}.pt"
),
model,
epoch,
opt,
performance=result,
)
if result["image_f1/AVG_ensemble"] > best_ensemble_image_f1:
best_ensemble_image_f1 = result["image_f1/AVG_ensemble"]
misc.save_model(
os.path.join(
opt.save_root_path, opt.dir_name, "checkpoint", "best.pt"
),
model,
epoch,
opt,
performance=result,
)
misc.update_record(result, epoch, opt, "best_record")
misc.update_record(result, epoch, opt, "latest_record")
print("best performance:", best_ensemble_image_f1)
if __name__ == "__main__":
opt = get_opt()
# import cProfile
# import pstats
# profiler = cProfile.Profile()
# profiler.enable()
    st = datetime.datetime.now()
    main(opt)
    total_time = datetime.datetime.now() - st
    # use total_seconds() (not .seconds) so runs longer than a day are counted
    total_time = str(datetime.timedelta(seconds=int(total_time.total_seconds())))
    print(f"Total time: {total_time}")
    print("finished")
print("finished")
# profiler.disable()
# stats = pstats.Stats(profiler).sort_stats('cumtime')
# stats.strip_dirs()
# stats_name = f'cprofile-data{opt.suffix}'
# if not opt.debug and not opt.eval:
# stats_name = os.path.join(opt.save_root_path, opt.dir_name, stats_name)
# else:
# stats_name = os.path.join('tmp', stats_name)
# stats.dump_stats(stats_name)