import datetime
import time

import torch


def str2bool(s):
    """Interpret a command-line style string as a boolean.

    Returns True only for ``"true"`` (any letter case) or ``"1"``; every other
    string, including ``"yes"``, yields False.
    """
    return s.lower() in ('true', '1')


class Timer:
    """Named stopwatch: ``start(key)`` records a start time, ``end(key)`` returns
    the seconds elapsed since then and forgets the key."""

    def __init__(self):
        # Maps a timer key to its start timestamp from ``time.perf_counter()``.
        self.clock = {}

    def start(self, key="default"):
        """Start (or restart) the timer named *key*."""
        # perf_counter is monotonic, so intervals are immune to wall-clock jumps
        # (NTP adjustments, DST changes) that can corrupt datetime.now() deltas.
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop the timer named *key* and return the elapsed time in seconds.

        Raises:
            KeyError: if *key* was never started or was already ended.
        """
        try:
            started = self.clock.pop(key)
        except KeyError:
            raise KeyError(f"{key} is not in the clock.") from None
        return time.perf_counter() - started


def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Write a full training checkpoint plus a standalone copy of the model weights.

    *checkpoint_path* receives a dict with keys ``'epoch'``, ``'model'``,
    ``'optimizer'`` and ``'best_score'``; *model_path* receives only
    *net_state_dict*.
    """
    torch.save({
        'epoch': epoch,
        'model': net_state_dict,
        'optimizer': optimizer_state_dict,
        'best_score': best_score,
    }, checkpoint_path)
    torch.save(net_state_dict, model_path)


def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously written with ``torch.save``.

    Args:
        checkpoint_path: file to read.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``"cpu"`` to load a GPU-saved checkpoint on a CPU-only host).
            ``None`` keeps ``torch.load``'s default behaviour.
    """
    return torch.load(checkpoint_path, map_location=map_location)


def freeze_net_layers(net):
    """Disable gradient computation for every parameter of *net* (in place)."""
    for param in net.parameters():
        param.requires_grad = False


def store_labels(path, labels):
    """Write *labels* to *path*, one per line, with no trailing newline.

    The file is written as UTF-8 explicitly so that the output does not depend
    on the platform's default locale encoding.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))