|
''' |
|
Author: Chris Xiao [email protected] |
|
Date: 2023-09-11 18:27:02 |
|
LastEditors: Chris Xiao [email protected] |
|
LastEditTime: 2023-12-17 18:22:47 |
|
FilePath: /EndoSAM/endoSAM/train.py |
|
Description: fine-tune training script |
|
I Love IU |
|
Copyright (c) 2023 by Chris Xiao [email protected], All Rights Reserved. |
|
''' |
|
''' |
|
@copyright Chris Xiao [email protected] |
|
''' |
|
import argparse |
|
from omegaconf import OmegaConf |
|
from torch.utils.data import DataLoader |
|
import os |
|
from dataset import EndoVisDataset |
|
from utils import make_if_dont_exist, setup_logger, one_hot_embedding_3d, save_checkpoint, plot_progress |
|
import datetime |
|
import torch |
|
from model import EndoSAMAdapter |
|
import numpy as np |
|
from segment_anything.build_sam import sam_model_registry |
|
from loss import ce_loss, mse_loss |
|
from tqdm import tqdm |
|
import wget |
|
|
|
# Download URLs for the official Segment Anything checkpoints, keyed by the
# value expected in ``cfg.model.sam_model_type``. The 'default' entry points at
# the same file as 'vit_h'.
COMMON_MODEL_LINKS = dict(
    default='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    vit_h='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
)
|
|
|
def parse_command(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests parse arguments programmatically.

    Returns:
        argparse.Namespace with:
            cfg (str | None): path to the config file, ``None`` if omitted.
            resume (bool): ``True`` when ``--resume`` was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str, help='path to config file')
    parser.add_argument('--resume', action='store_true', help='use this if you want to continue a training')
    return parser.parse_args(argv)
|
|
|
def _sam_checkpoint_missing(cfg):
    """Return True when the config does not point at a usable SAM checkpoint.

    A download is needed when ``model.sam_model_dir`` is absent, left as
    OmegaConf's mandatory-missing marker, set to null, or names a path that
    does not exist on disk.
    """
    model_cfg = cfg.model
    if 'sam_model_dir' not in model_cfg or OmegaConf.is_missing(model_cfg, 'sam_model_dir'):
        return True
    sam_model_dir = model_cfg.sam_model_dir
    # Check for None first: os.path.exists(None) raises TypeError.
    return sam_model_dir is None or not os.path.exists(sam_model_dir)


def _download_sam_checkpoint(cfg, cfg_path):
    """Download the SAM weights for ``cfg.model.sam_model_type`` and record them.

    The file is written to ``<parent of cwd>/sam_ckpts/<sam_model_type>.pth``.
    ``model.sam_model_dir`` is updated in ``cfg`` and the config is saved back
    to ``cfg_path`` so later runs reuse the download.
    """
    print("Didn't find SAM Checkpoint. Downloading from Facebook AI...")
    # os.path.dirname is portable; splitting on '/' is not.
    model_dir = os.path.join(os.path.dirname(os.getcwd()), 'sam_ckpts')
    # NOTE(review): whether overwrite=True clears checkpoints downloaded for
    # other model types depends on utils.make_if_dont_exist -- confirm.
    make_if_dont_exist(model_dir, overwrite=True)
    model_type = cfg.model.sam_model_type
    checkpoint = os.path.join(model_dir, model_type + '.pth')
    wget.download(COMMON_MODEL_LINKS[model_type], checkpoint)
    OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
    OmegaConf.save(cfg, cfg_path)


def _batch_loss(model, cfg, img, ann, device):
    """Return the weighted CE + MSE loss of ``model`` on one batch.

    The annotation gets an extra axis at dim 1, is cast to integer labels and
    one-hot encoded with ``cfg.model.class_num`` classes before being compared
    with the model's first output. The model's second output is not used.
    """
    img = img.to(device)
    ann = ann.to(device).unsqueeze(1).long()
    ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
    pred, _ = model(img)
    return cfg.losses.ce.weight * ce_loss(ann, pred) + cfg.losses.mse.weight * mse_loss(ann, pred)


def _train_one_epoch(model, loader, optimizer, cfg, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss.

    The mean is returned as a plain Python float (not ``np.float64``) so loss
    histories stored in checkpoints stay loadable with
    ``torch.load(weights_only=True)``.
    """
    model.train()
    losses = []
    with tqdm(loader, unit='batch', desc='Training') as tdata:
        for img, ann, _, _ in tdata:
            optimizer.zero_grad()
            loss = _batch_loss(model, cfg, img, ann, device)
            loss_value = loss.item()
            tdata.set_postfix(loss=loss_value)
            loss.backward()
            optimizer.step()
            losses.append(loss_value)
    return float(np.mean(losses, axis=0))


@torch.no_grad()
def _validate(model, loader, cfg, device):
    """Evaluate ``model`` over ``loader`` without gradients; return the mean batch loss as a float."""
    model.eval()
    losses = []
    with tqdm(loader, unit='batch', desc='Validation') as tdata:
        for img, ann, _, _ in tdata:
            loss_value = _batch_loss(model, cfg, img, ann, device).item()
            tdata.set_postfix(loss=loss_value)
            losses.append(loss_value)
    return float(np.mean(losses, axis=0))


if __name__ == '__main__':
    args = parse_command()
    cfg_path = args.cfg
    resume = args.resume
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if cfg_path is None:
        raise ValueError('config file not specified')
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f'config file {cfg_path} not found')
    cfg = OmegaConf.load(cfg_path)

    if _sam_checkpoint_missing(cfg):
        _download_sam_checkpoint(cfg, cfg_path)

    exp = cfg.experiment_name
    root_dir = cfg.dataset.dataset_dir
    img_format = cfg.dataset.img_format
    ann_format = cfg.dataset.ann_format
    model_path = cfg.model_folder
    log_path = cfg.log_folder
    ckpt_path = cfg.ckpt_folder
    plot_path = cfg.plot_folder
    model_exp_path = os.path.join(model_path, exp)
    log_exp_path = os.path.join(log_path, exp)
    ckpt_exp_path = os.path.join(ckpt_path, exp)
    plot_exp_path = os.path.join(plot_path, exp)

    if not resume:
        # NOTE(review): overwrite=True on the shared top-level folders may
        # remove other experiments' outputs, depending on
        # utils.make_if_dont_exist -- confirm before running many experiments.
        for folder in (model_path, log_path, ckpt_path, plot_path,
                       model_exp_path, log_exp_path, ckpt_exp_path, plot_exp_path):
            make_if_dont_exist(folder, overwrite=True)

    log_name = 'training_log_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.log'
    logger = setup_logger('EndoSAM', os.path.join(log_exp_path, log_name))
    logger.info(f"Welcome To {exp} Fine-Tuning")

    logger.info("Load Dataset-Specific Parameters")
    train_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format,
                                   mode='train', encoder_size=cfg.model.encoder_size)
    valid_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format,
                                   mode='val', encoder_size=cfg.model.encoder_size)
    train_loader = DataLoader(train_dataset, batch_size=cfg.train_bs, shuffle=True,
                              num_workers=cfg.num_workers)
    # Validation needs no shuffling. A fixed order keeps the mean of per-batch
    # losses (sensitive to which samples land in a short last batch)
    # comparable across epochs, so best-model selection is reproducible.
    valid_loader = DataLoader(valid_dataset, batch_size=cfg.val_bs, shuffle=False,
                              num_workers=cfg.num_workers)

    logger.info("Load Model-Specific Parameters")
    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](
        checkpoint=cfg.model.sam_model_dir, customized=cfg.model.sam_model_customized)
    model = EndoSAMAdapter(device, cfg.model.class_num, sam_mask_encoder, sam_prompt_encoder,
                           sam_mask_decoder, num_token=cfg.num_token).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.opt_params.lr_default)

    train_losses = []  # [epoch_number (1-based), mean training loss] per epoch
    val_losses = []  # [epoch_number (1-based), mean validation loss] per validated epoch
    best_val_loss = np.inf
    max_iter = cfg.max_iter  # number of epochs
    val_iter = cfg.val_iter  # validate when epoch % val_iter == 0
    start_epoch = 0

    if resume:
        ckpt = torch.load(os.path.join(ckpt_exp_path, 'ckpt.pth'), map_location=device)
        optimizer.load_state_dict(ckpt['optimizer'])
        model.load_state_dict(ckpt['weights'])
        best_val_loss = ckpt['best_val_loss']
        train_losses = ckpt['train_losses']
        val_losses = ckpt['val_losses']
        start_epoch = ckpt['epoch'] + 1
        logger.info("Resume Training")
    else:
        logger.info("Start Training")

    for epoch in range(start_epoch, max_iter):
        logger.info(f"Epoch {epoch+1}/{max_iter}:")

        train_loss = _train_one_epoch(model, train_loader, optimizer, cfg, device)
        logger.info(f"\ttraining loss: {train_loss}")
        train_losses.append([epoch + 1, train_loss])

        if epoch % val_iter == 0:
            val_loss = _validate(model, valid_loader, cfg, device)
            logger.info(f"\tvalidation loss: {val_loss}")
            val_losses.append([epoch + 1, val_loss])

            # Model selection compares validation losses only, so it lives
            # inside the validation branch.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                logger.info("\tsave best endosam model")
                torch.save({
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'endosam_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(model_exp_path, 'model.pth'))

        # Rolling checkpoint written every epoch; read back by --resume.
        save_dir = os.path.join(ckpt_exp_path, 'ckpt.pth')
        save_checkpoint(model, optimizer, epoch, best_val_loss, train_losses, val_losses, save_dir)
        plot_progress(logger, plot_exp_path, train_losses, val_losses, 'loss')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|