# RevCol / training / utils.py
# Source revision b9425fd ("Training Code: cls/det"), committed by LarryTsai.
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
import io
import os
import re
from typing import List
from timm.utils.model_ema import ModelEma
import torch
import torch.distributed as dist
from timm.utils import get_state_dict
import subprocess
def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    """Resume training state from ``config.MODEL.RESUME``.

    The checkpoint is fetched through ``torch.hub`` (with hash checking) when the
    path starts with ``https``; otherwise it is read from disk. Everything is
    mapped to CPU. Model weights are loaded with ``strict=False`` and the
    missing/unexpected-key report is logged.

    Args:
        config: experiment config. Reads ``MODEL.RESUME``, ``MODEL_EMA`` and
            ``EVAL_MODE``. When optimizer state is restored, it writes
            ``TRAIN.START_EPOCH`` (defrosting and re-freezing the config).
        model: module that receives ``checkpoint['model']``.
        optimizer: restored from ``checkpoint['optimizer']`` unless
            ``config.EVAL_MODE`` is set or the checkpoint lacks
            ``'optimizer'``/``'epoch'``.
        logger: logger used for progress messages.
        model_ema: optional EMA wrapper whose ``.ema`` module receives the EMA
            weights. If ``config.MODEL_EMA`` is set but this is ``None``, a
            warning is logged and EMA restoration is skipped.

    Returns:
        float: ``checkpoint['max_accuracy']`` when optimizer/epoch state was
        restored and the key exists, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Already loaded checkpoint to memory..")
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if config.MODEL_EMA:
        if model_ema is None:
            logger.warning("config.MODEL_EMA is set but no model_ema was given; skipping EMA weights")
        else:
            # save_checkpoint writes 'state_dict_ema': None when training ran
            # without EMA, so key presence alone is not enough.
            ema_state = checkpoint.get('state_dict_ema')
            if ema_state is not None:
                model_ema.ema.load_state_dict(ema_state, strict=False)
                logger.info("Loaded state_dict_ema")
            else:
                model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
                logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        # Continue from the epoch after the one stored in the checkpoint.
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    # del checkpoint
    # torch.cuda.empty_cache()
    return max_accuracy
def load_checkpoint_finetune(config, model, logger, model_ema=None):
    """Initialise ``model`` (and optionally ``model_ema.ema``) from a pretrained checkpoint.

    Reads ``checkpoint['model']`` from ``config.MODEL.FINETUNE``, mapped to CPU.
    Every key that starts with ``'cls'`` is dropped; the rest are loaded with
    ``strict=False``, so shape-compatible backbone weights are reused and the
    missing/unexpected-key report is logged.

    Args:
        config: experiment config; reads ``MODEL.FINETUNE``.
        model: module that receives the filtered weights.
        logger: logger used for all messages, including the discarded keys.
        model_ema: optional EMA wrapper; when given, its ``.ema`` module receives
            the same filtered weights.
    """
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    for key, value in checkpoint.items():
        # 'cls*' parameters are treated as pretraining-only (see log text) and
        # are not transferred to the fine-tuning model.
        if key.startswith('cls'):
            logger.info(f'key: {key} is used for pretrain, discarded.')
            continue
        converted_weights[key] = value
    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info(f"==============> Loaded Pretraind statedict into EMA....................")
        logger.info(ema_msg)
    # Release the CPU copy (and any cached GPU blocks) before training starts.
    del checkpoint
    torch.cuda.empty_cache()
def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    """Write ``ckpt_epoch_{epoch}.pth`` to ``config.OUTPUT`` and update ``best.pth`` on improvement.

    The saved dict holds ``model``, ``optimizer``, ``max_accuracy`` (the larger of
    ``max_accuracy`` and ``epoch_accuracy``), ``epoch``, ``state_dict_ema`` (``None``
    when no EMA model is given) and ``config``. Earlier versions also stored an
    ``'input'`` entry; it referred to the Python builtin ``input`` function, not to
    data, so it has been dropped.

    Args:
        config: experiment config; ``config.OUTPUT`` is the destination directory.
        epoch: current epoch index, used in the filename and stored in the file.
        model: module whose ``state_dict()`` is saved.
        epoch_accuracy: accuracy reached this epoch.
        max_accuracy: best accuracy seen before this epoch.
        optimizer: optimizer whose ``state_dict()`` is saved.
        logger: logger for progress messages.
        model_ema: optional EMA model, serialized with timm's ``get_state_dict``.
    """
    if model_ema is not None:
        logger.info("Model EMA is not None...")
        ema_state = get_state_dict(model_ema)
    else:
        ema_state = None
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'max_accuracy': max(max_accuracy, epoch_accuracy),
                  'epoch': epoch,
                  'state_dict_ema': ema_state,
                  'config': config}
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    best_path = os.path.join(config.OUTPUT, 'best.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    # Strictly better than the previous best -> also refresh best.pth.
    if epoch_accuracy > max_accuracy:
        torch.save(save_state, best_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm over ``parameters`` as a Python float.

    Parameters whose ``.grad`` is ``None`` are skipped. For finite ``norm_type`` p,
    this is ``(sum_i ||g_i||_p ** p) ** (1/p)``, accumulated in Python floats. For
    ``norm_type=inf`` it is the largest absolute gradient entry. With no
    gradients, 0.0 is returned.

    Args:
        parameters: a single tensor or an iterable of tensors.
        norm_type: order of the norm; anything accepted by ``float()``,
            including ``'inf'``.

    Returns:
        float: the combined gradient norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        # The p-sum formula degenerates for p=inf; the inf-norm is a max.
        # Empty gradient tensors are skipped because an inf-norm reduction
        # over zero elements is undefined.
        return max((g.norm(norm_type).item() for g in grads if g.numel() > 0), default=0.0)
    total_norm = 0.0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)
def auto_resume_helper(output_dir, logger):
    """Find the most recently modified checkpoint in ``output_dir``.

    Only entries whose names start with ``ckpt_`` and end with ``pth`` are
    considered. The candidate list and the chosen file are logged.

    Args:
        output_dir: directory to scan (non-recursively).
        logger: logger for the candidate list and the selection.

    Returns:
        The joined path of the newest candidate by modification time, or
        ``None`` when there are no candidates.
    """
    candidates = [
        name
        for name in os.listdir(output_dir)
        if name.endswith('pth') and name.startswith('ckpt_')
    ]
    logger.info(f"All checkpoints founded in {output_dir}: {candidates}")
    if not candidates:
        return None
    full_paths = [os.path.join(output_dir, name) for name in candidates]
    newest = max(full_paths, key=os.path.getmtime)
    logger.info(f"The latest checkpoint founded: {newest}")
    return newest
def reduce_tensor(tensor):
    """Return the mean of ``tensor`` across all processes of the default group.

    The input is copied first, so the caller's tensor is left untouched. The
    copy is SUM-all-reduced and then divided in place by the world size.
    Requires ``torch.distributed`` to be initialised.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged
def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Undo channel-wise normalization of a float image tensor and clip to [0, 1].

    Computes ``tensor * std + mean`` per channel, the inverse of
    :class:`~torchvision.transforms.Normalize`, and then clips the result to
    ``[0.0, 1.0]``. This transform does not support PIL Images.

    .. note::
        By default this works out of place, i.e. it does not mutate the input
        tensor. With ``inplace=True`` the input tensor is modified and returned.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
        mean (sequence): Per-channel means that were subtracted during normalization.
        std (sequence): Per-channel standard deviations that were divided out during normalization.
        inplace (bool, optional): If True, modify ``tensor`` in place.

    Returns:
        Tensor: Denormalized image tensor with values clipped to [0, 1].

    Raises:
        TypeError: If ``tensor`` is not a floating-point ``torch.Tensor``.
        ValueError: If ``tensor`` has fewer than 3 dimensions or any ``std`` entry is zero.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))
    if not inplace:
        tensor = tensor.clone()
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
    # Reshape 1-D per-channel stats to (C, 1, 1) so they broadcast over H, W
    # (and over any leading batch dimension).
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean).clip_(0.0, 1.0)
    return tensor