import io |
import os |
import re |
from typing import List |
from timm.utils.model_ema import ModelEma |
import torch |
import torch.distributed as dist |
from timm.utils import get_state_dict |
import subprocess |
def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    """Resume model, EMA and optimizer state from ``config.MODEL.RESUME``.

    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url`` when the
    path starts with ``'https'`` and with ``torch.load`` otherwise; tensors are
    mapped to CPU in both cases. Model and EMA weights are loaded with
    ``strict=False``, so missing/unexpected keys are logged rather than raised.

    Args:
        config: Config node. Reads ``MODEL.RESUME``, ``MODEL_EMA`` and
            ``EVAL_MODE``; when the training state is restored it sets
            ``TRAIN.START_EPOCH`` to ``checkpoint['epoch'] + 1`` between
            ``defrost()`` and ``freeze()``.
        model: Module that receives ``checkpoint['model']``.
        optimizer: Optimizer restored from ``checkpoint['optimizer']`` when not in
            eval mode and the checkpoint has both ``'optimizer'`` and ``'epoch'``.
        logger: Logger used for progress messages.
        model_ema: Optional EMA wrapper exposing an ``.ema`` module. Only used when
            ``config.MODEL_EMA`` is truthy.

    Returns:
        ``checkpoint['max_accuracy']`` when the training state was restored and
        the key exists, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
    # reject checkpoints that pickle non-tensor objects (save_checkpoint stores the
    # config object). Confirm against the torch version in use.
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Already loaded checkpoint to memory..")

    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if config.MODEL_EMA:
        if model_ema is None:
            # Previously this dereferenced None and crashed; skip the EMA restore instead.
            logger.warning("config.MODEL_EMA is set but no model_ema was given; skipping EMA restore")
        else:
            # save_checkpoint writes 'state_dict_ema': None when no EMA was tracked,
            # so a present-but-None entry must also fall back to the model weights.
            ema_state = checkpoint.get('state_dict_ema')
            if ema_state is not None:
                model_ema.ema.load_state_dict(ema_state, strict=False)
                logger.info("Loaded state_dict_ema")
            else:
                model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
                logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    # Training state is only restored when resuming training (not in eval mode)
    # and the checkpoint carries both optimizer state and the epoch counter.
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    return max_accuracy
def load_checkpoint_finetune(config, model, logger, model_ema=None):
    """Initialise ``model`` (and optionally its EMA copy) from a pretraining checkpoint.

    Reads ``torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']``,
    discards every entry whose key starts with ``'cls'``, and loads the rest
    with ``strict=False``, so missing/unexpected keys are only logged.

    Args:
        config: Config node; reads ``MODEL.FINETUNE`` (checkpoint path).
        model: Module that receives the filtered weights.
        logger: Logger used for progress messages and load reports.
        model_ema: Optional EMA wrapper exposing an ``.ema`` module; when given,
            it receives the same filtered weights.
    """
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    for key, value in checkpoint.items():
        # Prefix test (equivalent to the former re.match(r'cls.*', key)).
        # NOTE(review): this also drops keys such as 'cls_token'; confirm that is intended.
        if key.startswith('cls'):
            logger.info(f'key: {key} is used for pretrain, discarded.')
            continue
        converted_weights[key] = value

    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info("==============> Loaded Pretraind statedict into EMA....................")
        logger.info(ema_msg)
    # Drop the checkpoint dict and release memory cached by the CUDA allocator.
    del checkpoint
    torch.cuda.empty_cache()
def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    """Write ``ckpt_epoch_{epoch}.pth`` to ``config.OUTPUT``; also write ``best.pth`` on improvement.

    The saved dict holds ``'model'``, ``'optimizer'``, ``'max_accuracy'`` (the
    larger of ``max_accuracy`` and ``epoch_accuracy``), ``'epoch'``,
    ``'state_dict_ema'`` (``None`` when ``model_ema`` is not given) and ``'config'``.
    The former ``'input'`` entry is no longer written: it stored the Python
    builtin ``input`` function, not any data.

    Args:
        config: Config node; ``config.OUTPUT`` is the target directory (created
            if missing). The object itself is pickled into the checkpoint.
        epoch: Epoch index used in the file name and stored in the checkpoint.
        model: Module whose ``state_dict()`` is saved.
        epoch_accuracy: Accuracy achieved this epoch.
        max_accuracy: Best accuracy seen before this epoch.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        logger: Logger used for progress messages.
        model_ema: Optional EMA wrapper, serialized via timm's ``get_state_dict``.
    """
    if model_ema is not None:
        logger.info("Model EMA is not None...")
        ema_state = get_state_dict(model_ema)
    else:
        ema_state = None
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'max_accuracy': max(max_accuracy, epoch_accuracy),
        'epoch': epoch,
        'state_dict_ema': ema_state,
        'config': config,
    }

    os.makedirs(config.OUTPUT, exist_ok=True)
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    best_path = os.path.join(config.OUTPUT, 'best.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    # Strictly greater: a tie with the previous best keeps the existing best.pth.
    if epoch_accuracy > max_accuracy:
        torch.save(save_state, best_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the combined ``norm_type``-norm of the gradients of ``parameters``.

    Each gradient's norm is computed separately, and the results are combined as
    if all gradients formed one vector.

    Args:
        parameters: A tensor or an iterable of tensors. Entries whose ``.grad``
            is ``None`` are skipped.
        norm_type: Order of the norm. ``float('inf')`` gives the largest
            absolute gradient entry.

    Returns:
        The norm as a Python float; ``0.0`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        # The p-norm accumulation below degenerates for p=inf: x**inf followed by
        # **(1/inf) == **0 always yields 1.0. Take the max of per-tensor maxima instead.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)
    # Accumulate in Python floats: sum(|g_i|_p ** p), then take the p-th root.
    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)
def auto_resume_helper(output_dir, logger):
    """Find the most recently modified ``ckpt_*pth`` file in ``output_dir``.

    Only file names that start with ``'ckpt_'`` and end with ``'pth'`` are
    considered; recency is judged by ``os.path.getmtime``.

    Args:
        output_dir: Directory to scan (must exist; ``os.listdir`` raises otherwise).
        logger: Logger used to report the candidates and the selection.

    Returns:
        Full path of the newest matching checkpoint, or ``None`` if there is none.
    """
    candidates = [
        name for name in os.listdir(output_dir)
        if name.endswith('pth') and name.startswith('ckpt_')
    ]
    logger.info(f"All checkpoints founded in {output_dir}: {candidates}")
    if not candidates:
        return None

    newest = max(
        (os.path.join(output_dir, name) for name in candidates),
        key=os.path.getmtime,
    )
    logger.info(f"The latest checkpoint founded: {newest}")
    return newest
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged over all ranks of the default process group.

    The input is left untouched. The copy is summed with ``dist.all_reduce``,
    then divided in place by ``dist.get_world_size()``. Assumes
    ``torch.distributed`` has been initialized by the caller.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    averaged.div_(world_size)
    return averaged
def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Denormalize a float tensor image with mean and standard deviation.

    Computes ``tensor * std + mean`` per channel and then clips the result to
    ``[0, 1]``, undoing :class:`~torchvision.transforms.Normalize`.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Denormalized tensor image, clipped to ``[0, 1]``.

    Raises:
        TypeError: If ``tensor`` is not a floating-point torch tensor.
        ValueError: If ``tensor`` has fewer than 3 dimensions or any ``std`` is zero.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))
    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Denormalizing multiplies by std (no division), but a zero std would wipe out
    # the channel's content; reject it as the torchvision Normalize check does.
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, which would collapse the channel '
                         'to its mean.'.format(dtype))
    # Per-channel 1-D stats are reshaped to (C, 1, 1) so they broadcast over H and W.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean).clip_(0.0, 1.0)
    return tensor