import io
import os
import re
import subprocess
from typing import List

import torch
import torch.distributed as dist
from timm.utils import get_state_dict
from timm.utils.model_ema import ModelEma

def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    """Resume model, optimizer and (optionally) EMA weights from config.MODEL.RESUME."""
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Loaded checkpoint into memory.")
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0

    if config.MODEL_EMA and model_ema is not None:
        if 'state_dict_ema' in checkpoint:
            model_ema.ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
            logger.info("Loaded state_dict_ema")
        else:
            model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
            logger.warning("Failed to find state_dict_ema, initializing EMA from the loaded model weights")

    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
    if 'max_accuracy' in checkpoint:
        max_accuracy = checkpoint['max_accuracy']

    return max_accuracy
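
# A minimal usage sketch, assuming a YACS-style `config` and an already-built `model`,
# `optimizer`, `logger` and optional `model_ema` (these names are illustrative, not
# defined in this module):
#
#   if config.MODEL.RESUME:
#       max_accuracy = load_checkpoint(config, model, optimizer, logger, model_ema=model_ema)
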
def load_checkpoint_finetune(config, model, logger, model_ema=None):
    """Load pretrained weights from config.MODEL.FINETUNE, dropping pretraining-only heads."""
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    for key in checkpoint.keys():
        if re.match(r'cls.*', key):
            # Keys from the pretraining head are not needed for finetuning.
            logger.info(f'key: {key} is used for pretrain, discarded.')
            continue
        converted_weights[key] = checkpoint[key]
    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info("==============> Loaded pretrained state_dict into EMA....................")
        logger.info(ema_msg)
    del checkpoint
    torch.cuda.empty_cache()
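
# Hedged usage sketch: when finetuning from a pretraining checkpoint, this is typically
# called once before the training loop (assuming `config.MODEL.FINETUNE` points at a
# .pth file whose 'model' entry matches the backbone, minus the `cls*` pretraining head):
#
#   if config.MODEL.FINETUNE:
#       load_checkpoint_finetune(config, model, logger, model_ema=model_ema)
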
def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    """Save the training state for this epoch; also update best.pth when accuracy improves."""
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'max_accuracy': max(max_accuracy, epoch_accuracy),
                  'epoch': epoch,
                  'state_dict_ema': None,
                  'config': config}
    if model_ema is not None:
        logger.info("Model EMA is not None, saving its state_dict as well...")
        save_state['state_dict_ema'] = get_state_dict(model_ema)

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    best_path = os.path.join(config.OUTPUT, 'best.pth')

    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    if epoch_accuracy > max_accuracy:
        torch.save(save_state, best_path)
    logger.info(f"{save_path} saved !!!")
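
# Illustrative call at the end of each training epoch; the surrounding loop and the
# `acc1`/`max_accuracy` bookkeeping are assumptions, not defined in this module:
#
#   save_checkpoint(config, epoch, model_without_ddp, acc1, max_accuracy, optimizer, logger, model_ema)
#   max_accuracy = max(max_accuracy, acc1)
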
def get_grad_norm(parameters, norm_type=2):
    """Compute the total gradient norm of the given parameters (no clipping is applied)."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
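
# Typical use inside a training step, after backward() and before optimizer.step();
# unlike torch.nn.utils.clip_grad_norm_, this only measures the norm and does not clip:
#
#   loss.backward()
#   grad_norm = get_grad_norm(model.parameters())
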
def auto_resume_helper(output_dir, logger):
    """Return the most recently modified ckpt_*.pth in output_dir, or None if there is none."""
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and ckpt.startswith('ckpt_')]
    logger.info(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        logger.info(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
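
# Sketch of the auto-resume flow (assuming the config is mutable via defrost()/freeze(),
# as with the YACS-style config used above):
#
#   resume_file = auto_resume_helper(config.OUTPUT, logger)
#   if resume_file is not None:
#       config.defrost()
#       config.MODEL.RESUME = resume_file
#       config.freeze()
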
def reduce_tensor(tensor):
    """Average a tensor across all distributed processes."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
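
# Example: averaging a per-GPU metric tensor during distributed evaluation (assumes the
# default process group has already been initialized):
#
#   acc1 = reduce_tensor(acc1)
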
def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Denormalize a float tensor image with mean and standard deviation.

    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Denormalized tensor image, clipped to [0, 1].
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))

    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))

    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}; such channels cannot have been '
                         'normalized and so cannot be denormalized.'.format(dtype))
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    # Invert Normalize: x * std + mean, then clip back into the valid image range.
    tensor.mul_(std).add_(mean).clip_(0.0, 1.0)
    return tensor
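
# Example: undoing an ImageNet-style normalization before visualizing a batch (the
# mean/std values are the usual ImageNet constants, used here purely for illustration):
#
#   imgs = denormalize(imgs, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
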