|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from advanced_config import AdvancedConfig |
|
from advanced_logger import AdvancedLogger, LogPriority |
|
import os |
|
from os.path import join |
|
from plot_froc import plot_froc |
|
from train_frcnn import main as TRAIN_FRCNN |
|
from train_bilateral import main as TRAIN_BILATERAL |
|
import torch |
|
from model_utils import generate_predictions, generate_predictions_bilateral |
|
import argparse |
|
from dataloaders import get_dict |
|
from utils import create_backup |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
class Experimenter:
    """Drive one training/evaluation experiment described by a config file.

    Everything produced by the experiment lives under ``<BASE_DIR>/<EXP_NAME>``:
    a copy of the config, the logger output, TensorBoard logs (``tensor_logs``),
    a backup made by ``create_backup`` (``scripts``), model checkpoints, FROC
    point files and a FROC plot. The object itself is handed to the trainers
    (``TRAIN_FRCNN`` / ``TRAIN_BILATERAL``) — presumably so they can call back
    into ``start_epoch`` / ``end_epoch`` / ``save_model`` /
    ``generate_predictions``; confirm against train_frcnn / train_bilateral.
    """

    def __init__(self, cfg_file, BASE_DIR='experiments'):
        """Load the config and prepare the experiment directory and loggers.

        Args:
            cfg_file: path of the config file parsed by ``AdvancedConfig``.
            BASE_DIR: parent directory in which ``EXP_NAME`` is created.
        """
        self.cfg_file = cfg_file

        self.con = AdvancedConfig(cfg_file)
        self.config = self.con.config
        self.exp_dir = join(BASE_DIR, self.config['EXP_NAME'])
        os.makedirs(self.exp_dir, exist_ok=True)
        # Store the exact config next to the results so the run can be reproduced.
        self.con.save(join(self.exp_dir, 'config.cfg'))

        self.logger = AdvancedLogger(self.exp_dir)
        self.logger.log('Experiment:', self.config['EXP_NAME'], priority=LogPriority.STATS)
        self.logger.log('Experiment Description:', self.config['EXP_DESC'], priority=LogPriority.STATS)
        self.logger.log('Config File:', self.cfg_file, priority=LogPriority.STATS)
        self.logger.log('Experiment started', priority=LogPriority.LOW)

        # Per-mode metric histories, filled by init_losses / end_epoch.
        self.losses = dict()
        self.frocs = dict()

        self.writer = SummaryWriter(join(self.exp_dir, 'tensor_logs'))

        create_backup(backup_dir=join(self.exp_dir, 'scripts'))

    def log(self, *args, **kwargs):
        """Forward all arguments to the experiment's ``AdvancedLogger.log``."""
        self.logger.log(*args, **kwargs)

    def init_losses(self, mode):
        """Create empty loss/FROC histories for every model trained in ``mode``.

        ``FRCNN_BILATERAL`` trains both models, so it needs both histories;
        the two checks are therefore independent rather than ``if``/``elif``.
        """
        if mode in ('FRCNN', 'FRCNN_BILATERAL'):
            self.losses['frcnn_loss'] = []
            self.frocs['frcnn_froc'] = []
        if mode in ('BILATERAL', 'FRCNN_BILATERAL'):
            self.losses['bilateral_loss'] = []
            self.frocs['bilateral_froc'] = []

    def start_epoch(self):
        """Advance the epoch counter and log the new epoch number."""
        self.curr_epoch += 1
        self.logger.log('Epoch:', self.curr_epoch, priority=LogPriority.MEDIUM)

    def end_epoch(self, loss, model=None, device=None):
        """Record a finished epoch's validation loss and, if configured, its FROC.

        The loss is appended to the current mode's history and ``best_loss`` is
        updated. When ``EVAL_METHOD == 'FROC'`` the model is run on the
        validation split. The first sensitivity returned by ``get_froc_points``
        for the requested FP rate is appended to the FROC history, and
        ``best_froc`` is updated. The loss is always written to TensorBoard
        under ``<curr_mode>/Loss/Valid``.

        Args:
            loss: validation loss of the finished epoch.
            model: model to evaluate; required when ``EVAL_METHOD == 'FROC'``.
            device: device passed through to the prediction helpers.
        """
        if self.curr_mode == 'FRCNN':
            self.losses['frcnn_loss'].append(loss)
            self.best_loss = min(self.losses['frcnn_loss'])

            if self.config['EVAL_METHOD'] == 'FROC':
                exp_name = self.config['EXP_NAME']
                _, val_path, _ = self.init_paths()
                preds_folder = f'preds_frcnn_{exp_name}'
                generate_predictions(model, device, val_path, preds_folder)

                from froc_by_pranjal import get_froc_points
                # NOTE(review): FRCNN is scored at 0.2 FPs/image but BILATERAL at
                # 0.1 below — confirm this asymmetry is intended.
                senses, _ = get_froc_points(preds_folder, root_fol=val_path, fps_req=[0.2])

                self.frocs['frcnn_froc'].append(senses[0])
                self.best_froc = max(self.frocs['frcnn_froc'])
                self.logger.log(f'Val FROC: {senses[0]}', priority=LogPriority.MEDIUM)
                self.logger.log(f'Best FROC: {self.best_froc}')

        elif self.curr_mode == 'BILATERAL':
            self.losses['bilateral_loss'].append(loss)
            self.best_loss = min(self.losses['bilateral_loss'])

            if self.config['EVAL_METHOD'] == 'FROC':
                exp_name = self.config['EXP_NAME']
                _, val_path, _ = self.init_paths()
                data_dir = self.config['DATA_DIR']
                preds_folder = f'preds_bilateral_{exp_name}'

                print('Generating')
                generate_predictions_bilateral(
                    model, device, val_path,
                    get_dict(data_dir, self.abs_path(self.config['AIIMS_CORRS_LIST'])),
                    preds_folder=preds_folder,
                )
                print('Generation Done')

                from froc_by_pranjal import get_froc_points
                senses, _ = get_froc_points(preds_folder, root_fol=val_path, fps_req=[0.1])
                print('Reading Sens from', preds_folder, val_path)

                self.frocs['bilateral_froc'].append(senses[0])
                self.best_froc = max(self.frocs['bilateral_froc'])
                self.logger.log(f'Val FROC: {senses[0]}', priority=LogPriority.MEDIUM)
                self.logger.log(f'Best FROC: {self.best_froc}')

        self.writer.add_scalar(f"{self.curr_mode}/Loss/Valid", loss, self.curr_epoch)

    def save_model(self, model):
        """Checkpoint ``model`` for this epoch and refresh the best-model file.

        Every call writes ``<model>_<curr_epoch>.pth``. The un-suffixed
        ``<model>.pth`` is overwritten only when the latest recorded metric is
        the best so far: FROC when ``EVAL_METHOD == 'FROC'``, loss otherwise.
        It reads the last entry of the metric history, so it must be called
        after ``end_epoch`` for the same epoch.

        Raises:
            ValueError: if ``curr_mode`` is neither ``FRCNN`` nor ``BILATERAL``.
        """
        use_froc = self.config['EVAL_METHOD'] == 'FROC'

        if self.curr_mode == 'FRCNN':
            self.logger.log('Saving FRCNN Model', priority=LogPriority.LOW)
            model_file = join(self.exp_dir, 'frcnn_models', 'frcnn_model.pth')
            if use_froc:
                SAVE = self.best_froc == self.frocs['frcnn_froc'][-1]
            else:
                SAVE = self.best_loss == self.losses['frcnn_loss'][-1]
        elif self.curr_mode == 'BILATERAL':
            self.logger.log('Saving Bilateral Model', priority=LogPriority.LOW)
            model_file = join(self.exp_dir, 'bilateral_models', 'bilateral_model.pth')
            if use_froc:
                SAVE = self.best_froc == self.frocs['bilateral_froc'][-1]
            else:
                SAVE = self.best_loss == self.losses['bilateral_loss'][-1]
        else:
            raise ValueError(f'Cannot save model in unknown mode {self.curr_mode!r}')

        os.makedirs(os.path.dirname(model_file), exist_ok=True)

        if SAVE:
            torch.save(model.state_dict(), model_file)

        # Per-epoch checkpoint, e.g. frcnn_model_3.pth.
        torch.save(model.state_dict(), f'{os.path.splitext(model_file)[0]}_{self.curr_epoch}.pth')

    def init_paths(self,):
        """Return the ``(train, val, test)`` split directories.

        Each one is ``DATA_DIR/AIIMS_DATA/<split>``.
        """
        data_root = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'])
        train_path = join(data_root, self.config['AIIMS_TRAIN_SPLIT'])
        val_path = join(data_root, self.config['AIIMS_VAL_SPLIT'])
        test_path = join(data_root, self.config['AIIMS_TEST_SPLIT'])
        return train_path, val_path, test_path

    def abs_path(self, path):
        """Resolve ``path`` relative to the configured ``DATA_DIR``."""
        return join(self.config['DATA_DIR'], path)

    def generate_predictions(self, model, device):
        """Load the current mode's best checkpoint and predict on every split.

        Predictions for the train, val and test splits are passed to the
        prediction helpers with folder ``preds_frcnn_<EXP_NAME>`` or
        ``preds_bilateral_<EXP_NAME>``.

        Raises:
            ValueError: if ``curr_mode`` is neither ``FRCNN`` nor ``BILATERAL``.
        """
        self.logger.log('Generating Predictions')
        self.logger.flush()

        exp_name = self.config['EXP_NAME']
        split_paths = self.init_paths()

        if self.curr_mode == 'FRCNN':
            model_file = join(self.exp_dir, 'frcnn_models', 'frcnn_model.pth')
        elif self.curr_mode == 'BILATERAL':
            model_file = join(self.exp_dir, 'bilateral_models', 'bilateral_model.pth')
        else:
            raise ValueError(f'Cannot generate predictions in unknown mode {self.curr_mode!r}')

        # map_location lets a checkpoint saved on one device be restored on `device`.
        model.load_state_dict(torch.load(model_file, map_location=device))

        if self.curr_mode == 'FRCNN':
            for split_path in split_paths:
                generate_predictions(model, device, split_path, f'preds_frcnn_{exp_name}')
        else:
            data_dir = self.config['DATA_DIR']
            for split_path in split_paths:
                # A fresh correspondence dict per split, as before, in case the
                # helper mutates it (not verifiable from here).
                corrs = get_dict(data_dir, self.abs_path(self.config['AIIMS_CORRS_LIST']))
                generate_predictions_bilateral(model, device, split_path, corrs, 'aiims', f'preds_bilateral_{exp_name}')

    def _log_test_froc(self, title, preds_folder, save_to):
        """Score ``preds_folder`` on the test split, save points to ``save_to`` and log them."""
        from froc_by_pranjal import get_froc_points
        test_path = self.init_paths()[2]
        senses, fps = get_froc_points(preds_folder, root_fol=test_path, save_to=save_to)
        self.logger.log(f'{title} RESULTS', priority=LogPriority.STATS)
        for s, f in zip(senses, fps):
            self.logger.log(f'Sensitivty at {f}: {s}', priority=LogPriority.STATS)

    def run_experiment(self):
        """Train according to ``config['MODE']``, then report and plot test FROC.

        Modes:
            ``FRCNN``: train the FRCNN only.
            ``BILATERAL``: train the bilateral model only.
            ``FRCNN_BILATERAL``: train the FRCNN, then train the bilateral model
            with ``FRCNN_MODEL_PATH`` pointing at the best FRCNN checkpoint.

        Afterwards, test-split FROC points are read from the prediction folders
        and written to ``senses_fps_<model>.txt``. They are also logged and
        plotted to ``plot.png`` in ``exp_dir``.

        Raises:
            ValueError: if ``MODE`` is not one of the modes above.
        """
        mode = self.config['MODE']
        if mode not in ('FRCNN', 'BILATERAL', 'FRCNN_BILATERAL'):
            raise ValueError(
                f"Unknown MODE {mode!r}; expected 'FRCNN', 'BILATERAL' or 'FRCNN_BILATERAL'"
            )

        self.init_losses(mode)
        self.curr_mode = 'FRCNN'
        self.curr_epoch = -1  # start_epoch increments before logging, so the first epoch is 0
        self.best_loss = 999999
        self.best_froc = 0

        if mode == 'FRCNN':
            TRAIN_FRCNN(self.config['FRCNN'], self)
        elif mode == 'BILATERAL':
            self.curr_mode = 'BILATERAL'
            TRAIN_BILATERAL(self.config['BILATERAL'], self)
        else:  # FRCNN_BILATERAL
            TRAIN_FRCNN(self.config['FRCNN'], self)

            # Reset the per-phase bookkeeping before the bilateral phase.
            self.curr_mode = 'BILATERAL'
            self.curr_epoch = -1
            self.best_loss = 999999
            self.best_froc = 0

            # The bilateral model builds on the best FRCNN checkpoint saved by save_model.
            self.config['BILATERAL']['FRCNN_MODEL_PATH'] = join(self.exp_dir, 'frcnn_models', 'frcnn_model.pth')
            TRAIN_BILATERAL(self.config['BILATERAL'], self)

        self.logger.log(f'Best Loss: {self.best_loss}', priority=LogPriority.STATS)
        self.logger.log('Experiment Training and Generation Ended', priority=LogPriority.MEDIUM)

        frcnn_file = join(self.exp_dir, 'senses_fps_frcnn.txt')
        bilateral_file = join(self.exp_dir, 'senses_fps_bilateral.txt')
        exp_name = self.config['EXP_NAME']

        if mode in ('FRCNN', 'FRCNN_BILATERAL'):
            self._log_test_froc('FRCNN', f'preds_frcnn_{exp_name}', frcnn_file)
        if mode in ('BILATERAL', 'FRCNN_BILATERAL'):
            self._log_test_froc('BILATERAL', f'preds_bilateral_{exp_name}', bilateral_file)

        plot_file = join(self.exp_dir, 'plot.png')
        if mode == 'FRCNN':
            plot_froc({frcnn_file: 'FRCNN'}, plot_file, TITLE='FRCNN FROC')
        elif mode == 'BILATERAL':
            plot_froc({bilateral_file: 'BILATERAL'}, plot_file, TITLE='BILATERAL FROC')
        else:
            plot_froc({frcnn_file: 'FRCNN', bilateral_file: 'BILATERAL'}, plot_file, TITLE='FRCNN vs BILATERAL FROC')

        self.logger.flush()
        # Flush pending TensorBoard events and release the event file.
        self.writer.close()
|
|
|
if __name__ == '__main__':
    # Command-line entry point: run one experiment from the given config file.
    cli = argparse.ArgumentParser()
    cli.add_argument('--cfg_file', type=str, default='configs/AIIMS_C1.cfg')
    options = cli.parse_args()
    experiment = Experimenter(options.cfg_file)
    experiment.run_experiment()