Spaces:
Runtime error
Runtime error
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py | |
from __future__ import division | |
import logging | |
from utils import CheckpointSaver | |
from tensorboardX import SummaryWriter | |
import torch | |
from tqdm import tqdm | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# A monitor_interval of 0 disables tqdm's background monitor thread.
tqdm.monitor_interval = 0
class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of checkpointing/logging/resuming training. Subclasses must
    implement ``init_fn`` (which has to populate ``self.models_dict`` and
    ``self.optimizers_dict``), ``train``, ``train_step``, ``train_summaries``
    and ``visualize``. ``validate``, ``test`` and ``evaluate`` default to
    no-ops.
    """

    def __init__(self, options):
        """Set up the device, checkpoint saver, summary writer and resume state.

        Args:
            options: Parsed training options. Attributes read here:
                ``multiprocessing_distributed``, ``gpu``, ``checkpoint_dir``,
                ``overwrite``, ``rank``, ``summary_dir`` and ``resume``.
        """
        self.options = options
        if options.multiprocessing_distributed:
            # In distributed mode the device is the GPU index in options.gpu.
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # The summary writer is created only on rank 0. On other ranks the
        # attribute ``self.summary_writer`` does not exist.
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)

        # Subclass hook: define models, optimizers etc. It must create
        # ``self.models_dict`` and ``self.optimizers_dict``, which are
        # passed to the saver below when resuming.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            # Resume counters and the in-epoch batch position from the checkpoint.
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']
        # Starts at +inf; presumably tracks a lower-is-better metric that
        # subclasses update (TODO confirm against subclasses).
        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.

        This is different from resuming training using --resume. Only model
        weights are restored (with ``strict=True``). Optimizers and the
        epoch/step counters are left untouched.

        Args:
            checkpoint_file: Path to a file readable by ``torch.load`` that
                holds a dict whose keys match entries of ``self.models_dict``.
                ``None`` makes this a no-op.
        """
        if checkpoint_file is None:
            return
        # Load onto CPU so checkpoints saved on a GPU also load on CPU-only
        # machines. load_state_dict then copies the weights onto the device
        # each model already lives on.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        for model in self.models_dict:
            if model in checkpoint:
                self.models_dict[model].load_state_dict(checkpoint[model],
                                                        strict=True)
                print(f'Checkpoint {model} loaded')

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move every tensor value of ``dict`` to ``device`` in place.

        Non-tensor values are left unchanged. The mapping is mutated and
        nothing is returned.

        Args:
            dict: Mapping whose tensor values are replaced in place. The name
                shadows the builtin but is kept for keyword-argument callers.
            device: Target device passed to ``Tensor.to``.
            tensor2float: If True, cast tensors with ``.float()`` before moving.
        """
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                # Reassigning existing keys does not change the dict's size,
                # so mutating it while iterating over items() is safe.
                dict[k] = v.float().to(device) if tensor2float else v.to(device)

    # The following methods must be implemented in derived classes.
    # validate/test/evaluate below are optional no-op hooks.
    def train(self, epoch):
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        pass

    def test(self):
        pass

    def evaluate(self):
        pass

    def fit(self):
        """Run training for epochs ``self.epoch_count`` .. ``num_epochs - 1``.

        When resuming, the loop restarts at the checkpointed epoch and the
        progress bar is offset by ``initial`` so it still counts out of
        ``options.num_epochs``. ``self.epoch_count`` is updated before each
        call to ``train``.
        """
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)