import os
import time
import argparse
from glob import glob

import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

from video_module.dataset import Water_Image_Train_DS
from video_module.model import AFB_URR, FeatureBank
import myutils

# CUDA launch blocking aids debugging of asynchronous kernel errors;
# set to '1' to enable it (disabled here).
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'


def get_args():
    parser = argparse.ArgumentParser(description='Train V-BeachNet')
    parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset folder.')
    parser.add_argument('--seed', type=int, default=-1, help='Random seed.')
    parser.add_argument('--log', action='store_true', help='Save the training results.')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate (default: 1e-5).')
    parser.add_argument('--lu', type=float, default=0.5, help='Regularization factor (default: 0.5).')
    parser.add_argument('--resume', type=str, help='Path to checkpoint (default: none).')
    parser.add_argument('--new', action='store_true', help='Train the model from scratch.')
    parser.add_argument('--scheduler-step', type=int, default=25, help='Scheduler step size (default: 25).')
    parser.add_argument('--total-epochs', type=int, default=100, help='Total number of epochs (default: 100).')
    parser.add_argument('--budget', type=int, default=300000,
                        help='Maximum number of features in the feature bank (default: 300000).')
    parser.add_argument('--obj-n', type=int, default=2,
                        help='Maximum number of objects trained simultaneously.')
    parser.add_argument('--clip-n', type=int, default=6,
                        help='Maximum number of frames in a batch.')
    return parser.parse_args()
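
# Example invocations, for reference. The script name and dataset path below
# are hypothetical placeholders; substitute your local file name and a folder
# laid out as Water_Image_Train_DS expects.
#   Fresh run:  python train_video.py --dataset data/beach_clips --gpu 0 --log
#   Resume:     python train_video.py --dataset data/beach_clips --gpu 0 --log \
#                   --resume logs/<timestamp>/model/final.pth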

def train_model(model, dataloader, criterion, optimizer):
    stats = myutils.AvgMeter()
    uncertainty_stats = myutils.AvgMeter()
    progress_bar = tqdm(dataloader)

    for _, sample in enumerate(progress_bar):
        frames, masks, obj_n, info = sample

        # Skip clips that contain only the background object.
        if obj_n.item() == 1:
            continue

        # batch_size is 1, so drop the batch dimension.
        frames, masks = frames[0].to(device), masks[0].to(device)

        # Memorize the first frame of the clip into a fresh feature bank,
        # then segment the remaining frames against it.
        fb_global = FeatureBank(obj_n.item(), args.budget, device)
        k4_list, v4_list = model.memorize(frames[:1], masks[:1])
        fb_global.init_bank(k4_list, v4_list)

        scores, uncertainty = model.segment(frames[1:], fb_global)
        label = torch.argmax(masks[1:], dim=1).long()

        # Cross-entropy on the predicted masks plus the uncertainty
        # regularization term, weighted by --lu.
        optimizer.zero_grad()
        loss = criterion(scores, label) + args.lu * uncertainty
        loss.backward()
        optimizer.step()

        uncertainty_stats.update(uncertainty.item())
        stats.update(loss.item())
        progress_bar.set_postfix(
            loss=f'{loss.item():.5f} (Avg: {stats.avg:.5f}, Uncertainty Avg: {uncertainty_stats.avg:.5f})')

    return stats.avg


def main():
    dataset = Water_Image_Train_DS(root=args.dataset, output_size=400,
                                   clip_n=args.clip_n, max_obj_n=args.obj_n)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
    print(myutils.gct(), f'Dataset with {len(dataset)} training cases.')

    model = AFB_URR(device, update_bank=False, load_imagenet_params=True).to(device)
    model.train()
    model.apply(myutils.set_bn_eval)  # Keep BatchNorm layers frozen in eval mode during training.

    optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), args.lr)

    start_epoch, best_loss = 0, float('inf')
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model'], strict=False)
            seed = checkpoint.get('seed', int(time.time()))
            if not args.new:
                # Resume the full training state (epoch, optimizer, best loss).
                start_epoch = checkpoint['epoch'] + 1
                optimizer.load_state_dict(checkpoint['optimizer'])
                best_loss = checkpoint['loss']
                print(myutils.gct(),
                      f'Resumed from checkpoint {args.resume} (Epoch: {start_epoch - 1}, Best Loss: {best_loss}).')
            else:
                # Load only the model weights and restart training.
                print(myutils.gct(), f'Loaded checkpoint {args.resume}. Training from scratch.')
        else:
            raise FileNotFoundError(f'No checkpoint found at {args.resume}')
    else:
        seed = args.seed if args.seed >= 0 else int(time.time())

    print(myutils.gct(), 'Random seed:', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step,
                                                gamma=0.5, last_epoch=start_epoch - 1)

    for epoch in range(start_epoch, args.total_epochs):
        print(f'\n{myutils.gct()} Epoch: {epoch}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
        loss = train_model(model, dataloader, criterion, optimizer)

        if args.log:
            # Always save the latest state; additionally keep the best-loss model.
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss,
                'seed': seed,
            }
            torch.save(checkpoint, os.path.join(model_path, 'final.pth'))
            if loss < best_loss:
                best_loss = loss
                torch.save(checkpoint, os.path.join(model_path, 'best.pth'))
                print('Updated best model.')

        scheduler.step()


if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), f'Arguments: {args}')

    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        raise ValueError('CUDA is required. Ensure --gpu is set to >= 0.')

    if args.log:
        # Snapshot the code alongside the checkpoints for reproducibility.
        log_dir = os.path.join('logs', time.strftime('%Y%m%d-%H%M%S'))
        model_path = os.path.join(log_dir, 'model')
        os.makedirs(model_path, exist_ok=True)
        myutils.save_scripts(log_dir, scripts_to_save=glob('*.*'))
        myutils.save_scripts(log_dir, scripts_to_save=glob('dataset/*.py', recursive=True))
        myutils.save_scripts(log_dir, scripts_to_save=glob('model/*.py', recursive=True))
        myutils.save_scripts(log_dir, scripts_to_save=glob('myutils/*.py', recursive=True))
        print(myutils.gct(), f'Created log directory: {log_dir}')

    main()
    print(myutils.gct(), 'Training completed.')