'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-11 18:27:02
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 18:22:47
FilePath: /EndoSAM/endoSAM/train.py
Description: fine-tune training script
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
'''
@copyright Chris Xiao yl.xiao@mail.utoronto.ca
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import os
from dataset import EndoVisDataset
from utils import make_if_dont_exist, setup_logger, one_hot_embedding_3d, save_checkpoint, plot_progress
import datetime
import torch
from model import EndoSAMAdapter
import numpy as np
from segment_anything.build_sam import sam_model_registry
from loss import ce_loss, mse_loss
from tqdm import tqdm
import wget

COMMON_MODEL_LINKS = {
    'default': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
}


def parse_command():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str, help='path to config file')
    parser.add_argument('--resume', action='store_true', help='use this if you want to continue a training')
    args = parser.parse_args()
    return args

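# The YAML passed via --cfg must provide every key read below (experiment_name, folder paths,
# dataloader settings, dataset.*, model.*, opt_params.*, losses.*). A minimal illustrative layout
# is sketched here; the values are placeholders, not the settings shipped with the repository:
#
#   experiment_name: endovis_finetune
#   model_folder: ../models
#   log_folder: ../logs
#   ckpt_folder: ../ckpts
#   plot_folder: ../plots
#   train_bs: 4
#   val_bs: 4
#   num_workers: 4
#   num_token: 8
#   max_iter: 100
#   val_iter: 1
#   opt_params:
#     lr_default: 1.0e-4
#   losses:
#     ce: {weight: 1.0}
#     mse: {weight: 1.0}
#   dataset:
#     dataset_dir: /path/to/EndoVis
#     img_format: png
#     ann_format: png
#   model:
#     sam_model_type: vit_h
#     sam_model_dir: ???          # left missing so the script downloads the SAM checkpoint
#     sam_model_customized: false
#     encoder_size: 1024
#     class_num: 2
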
if __name__ == '__main__':
    args = parse_command()
    cfg_path = args.cfg
    resume = args.resume
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load the training configuration
    if cfg_path is not None:
        if os.path.exists(cfg_path):
            cfg = OmegaConf.load(cfg_path)
        else:
            raise FileNotFoundError(f'config file {cfg_path} not found')
    else:
        raise ValueError('config file not specified')

    # download the SAM checkpoint if it is missing from the config or not found on disk,
    # then write the resolved path back into the config file
    if 'sam_model_dir' not in OmegaConf.to_container(cfg)['model'].keys() or OmegaConf.is_missing(cfg.model, 'sam_model_dir') or not os.path.exists(cfg.model.sam_model_dir):
        print("Didn't find SAM Checkpoint.\nDownloading from Facebook AI...")
        parent_dir = '/'.join(os.getcwd().split('/')[:-1])
        model_dir = os.path.join(parent_dir, 'sam_ckpts')
        make_if_dont_exist(model_dir, overwrite=True)
        checkpoint = os.path.join(model_dir, cfg.model.sam_model_type + '.pth')
        wget.download(COMMON_MODEL_LINKS[cfg.model.sam_model_type], checkpoint)
        OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
        OmegaConf.save(cfg, cfg_path)

    exp = cfg.experiment_name
    root_dir = cfg.dataset.dataset_dir
    img_format = cfg.dataset.img_format
    ann_format = cfg.dataset.ann_format
    model_path = cfg.model_folder
    log_path = cfg.log_folder
    ckpt_path = cfg.ckpt_folder
    plot_path = cfg.plot_folder
    model_exp_path = os.path.join(model_path, exp)
    log_exp_path = os.path.join(log_path, exp)
    ckpt_exp_path = os.path.join(ckpt_path, exp)
    plot_exp_path = os.path.join(plot_path, exp)

    # create a fresh experiment directory tree unless we are resuming
    if not resume:
        make_if_dont_exist(model_path, overwrite=True)
        make_if_dont_exist(log_path, overwrite=True)
        make_if_dont_exist(ckpt_path, overwrite=True)
        make_if_dont_exist(plot_path, overwrite=True)
        make_if_dont_exist(model_exp_path, overwrite=True)
        make_if_dont_exist(log_exp_path, overwrite=True)
        make_if_dont_exist(ckpt_exp_path, overwrite=True)
        make_if_dont_exist(plot_exp_path, overwrite=True)

    datetime_object = 'training_log_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.log'
    logger = setup_logger('EndoSAM', os.path.join(log_exp_path, datetime_object))
    logger.info(f"Welcome To {exp} Fine-Tuning")

    logger.info("Load Dataset-Specific Parameters")
    train_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format, mode='train', encoder_size=cfg.model.encoder_size)
    valid_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format, mode='val', encoder_size=cfg.model.encoder_size)
    train_loader = DataLoader(train_dataset, batch_size=cfg.train_bs, shuffle=True, num_workers=cfg.num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=cfg.val_bs, shuffle=True, num_workers=cfg.num_workers)

    logger.info("Load Model-Specific Parameters")
    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](checkpoint=cfg.model.sam_model_dir, customized=cfg.model.sam_model_customized)
    model = EndoSAMAdapter(device, cfg.model.class_num, sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder, num_token=cfg.num_token).to(device)
    lr = cfg.opt_params.lr_default
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    val_losses = []
    best_val_loss = np.inf
    max_iter = cfg.max_iter
    val_iter = cfg.val_iter
    start_epoch = 0

    # restore model weights, optimizer state, and loss history when resuming
    if resume:
        ckpt = torch.load(os.path.join(ckpt_exp_path, 'ckpt.pth'), map_location=device)
        optimizer.load_state_dict(ckpt['optimizer'])
        model.load_state_dict(ckpt['weights'])
        best_val_loss = ckpt['best_val_loss']
        train_losses = ckpt['train_losses']
        val_losses = ckpt['val_losses']
        lr = optimizer.param_groups[0]['lr']
        start_epoch = ckpt['epoch'] + 1
        logger.info("Resume Training")
    else:
        logger.info("Start Training")

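    # Training objective (per batch), as computed inside the loops below:
    #     loss = w_ce * ce_loss(ann, pred) + w_mse * mse_loss(ann, pred)
    # where ann is the one-hot encoded ground-truth mask (cfg.model.class_num channels),
    # pred is the adapter's per-class prediction, and w_ce / w_mse come from
    # cfg.losses.ce.weight and cfg.losses.mse.weight. The exact reductions are defined in loss.py.
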
    # epoch loop: cfg.max_iter counts fine-tuning epochs, cfg.val_iter sets the validation interval
    for epoch in range(start_epoch, cfg.max_iter):
        logger.info(f"Epoch {epoch+1}/{cfg.max_iter}:")
        losses = []
        model.train()
        with tqdm(train_loader, unit='batch', desc='Training') as tdata:
            for img, ann, _, _ in tdata:
                img = img.to(device)
                ann = ann.to(device).unsqueeze(1).long()
                ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
                optimizer.zero_grad()
                pred, pred_quality = model(img)
                loss = cfg.losses.ce.weight * ce_loss(ann, pred) + cfg.losses.mse.weight * mse_loss(ann, pred)
                tdata.set_postfix(loss=loss.item())
                loss.backward()
                optimizer.step()
                losses.append(loss.item())

        avg_loss = np.mean(losses, axis=0)
        logger.info(f"\ttraining loss: {avg_loss}")
        train_losses.append([epoch+1, avg_loss])

        # run validation every cfg.val_iter epochs and keep the best-scoring model
        if epoch % cfg.val_iter == 0:
            model.eval()
            losses = []
            with torch.no_grad():
                with tqdm(valid_loader, unit='batch', desc='Validation') as tdata:
                    for img, ann, _, _ in tdata:
                        img = img.to(device)
                        ann = ann.to(device).unsqueeze(1).long()
                        ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
                        pred, pred_quality = model(img)
                        loss = cfg.losses.ce.weight * ce_loss(ann, pred) + cfg.losses.mse.weight * mse_loss(ann, pred)
                        tdata.set_postfix(loss=loss.item())
                        losses.append(loss.item())

            avg_loss = np.mean(losses, axis=0)
            logger.info(f"\tvalidation loss: {avg_loss}")
            val_losses.append([epoch+1, avg_loss])
            if avg_loss < best_val_loss:
                best_val_loss = avg_loss
                logger.info("\tsave best endosam model")
                torch.save({
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'endosam_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(model_exp_path, 'model.pth'))

        # snapshot the latest state for --resume and refresh the loss curves
        save_dir = os.path.join(ckpt_exp_path, 'ckpt.pth')
        save_checkpoint(model, optimizer, epoch, best_val_loss, train_losses, val_losses, save_dir)
        plot_progress(logger, plot_exp_path, train_losses, val_losses, 'loss')
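
# Typical invocations (the config path below is illustrative, not a file guaranteed to exist at
# this exact location):
#     python train.py --cfg ../config/endovis.yaml
#     python train.py --cfg ../config/endovis.yaml --resume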