File size: 7,111 Bytes
b9425fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
import io
import os
import re
from typing import List
from timm.utils.model_ema import ModelEma
import torch
import torch.distributed as dist
from timm.utils import get_state_dict
import subprocess
def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    """Resume model / EMA / optimizer state from ``config.MODEL.RESUME``.

    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url``
    (hash-checked) when the path starts with ``https``, otherwise read with
    ``torch.load``; in both cases tensors are mapped to CPU.

    Args:
        config: experiment config. Reads ``MODEL.RESUME``, ``MODEL_EMA`` and
            ``EVAL_MODE``; when training state is restored it writes
            ``TRAIN.START_EPOCH`` (defrosting/refreezing the config around it).
        model: module that receives ``checkpoint['model']`` (non-strict).
        optimizer: restored from ``checkpoint['optimizer']`` unless in eval
            mode or the checkpoint lacks ``'optimizer'`` / ``'epoch'``.
        logger: logger used for progress and warning messages.
        model_ema: optional object exposing ``.ema`` (a module). Only used
            when ``config.MODEL_EMA`` is truthy.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        the key exists, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Already loaded checkpoint to memory..")

    # Non-strict: the returned missing/unexpected-key report is logged instead.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if config.MODEL_EMA:
        if model_ema is None:
            logger.warning("MODEL_EMA is enabled but no model_ema was given; skipping EMA restore")
        # save_checkpoint stores ``'state_dict_ema': None`` when training ran
        # without EMA, so test the value, not just the key's presence.
        elif checkpoint.get('state_dict_ema') is not None:
            model_ema.ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
            logger.info("Loaded state_dict_ema")
        else:
            model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
            logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    # Training state is only restored when resuming training from a full checkpoint.
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    return max_accuracy
def load_checkpoint_finetune(config, model, logger, model_ema=None):
    """Initialise ``model`` (and optionally ``model_ema.ema``) for fine-tuning.

    Reads ``checkpoint['model']`` from ``config.MODEL.FINETUNE`` (mapped to
    CPU) and drops every entry whose key starts with ``'cls'`` before loading.
    Loading is non-strict, so the dropped parameters keep the values the
    target model already has.

    Args:
        config: experiment config; reads ``MODEL.FINETUNE``.
        model: module that receives the filtered state dict.
        logger: logger for progress messages and the list of discarded keys.
        model_ema: optional object exposing ``.ema``; when given, the same
            filtered weights are loaded into it.
    """
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    for key, value in checkpoint.items():
        # Prefix match on "cls": these entries belong to the pretraining head.
        if key.startswith('cls'):
            # Routed through the logger (not print) like every other message here.
            logger.info(f'key: {key} is used for pretrain, discarded.')
            continue
        converted_weights[key] = value
    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info("==============> Loaded Pretrained statedict into EMA....................")
        logger.info(ema_msg)
    del checkpoint
    torch.cuda.empty_cache()
def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    """Save a training checkpoint for ``epoch`` into ``config.OUTPUT``.

    Always writes ``ckpt_epoch_{epoch}.pth``; additionally writes ``best.pth``
    (same payload) when ``epoch_accuracy`` is strictly greater than
    ``max_accuracy``.

    Args:
        config: experiment config; ``config.OUTPUT`` is the target directory
            and the config object itself is stored in the payload.
        epoch: epoch index stored under ``'epoch'`` and used in the filename.
        model: module whose ``state_dict()`` is saved under ``'model'``.
        epoch_accuracy: accuracy of this epoch.
        max_accuracy: best accuracy seen before this epoch; the payload stores
            ``max(max_accuracy, epoch_accuracy)``.
        optimizer: optimizer whose ``state_dict()`` is saved.
        logger: logger for progress messages.
        model_ema: optional EMA wrapper; its state (via timm ``get_state_dict``)
            is stored under ``'state_dict_ema'``, otherwise ``None`` is stored.
    """
    if model_ema is not None:
        logger.info("Model EMA is not None...")
        ema_state = get_state_dict(model_ema)
    else:
        ema_state = None

    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'max_accuracy': max(max_accuracy, epoch_accuracy),
        'epoch': epoch,
        'state_dict_ema': ema_state,
        'config': config,
    }

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    best_path = os.path.join(config.OUTPUT, 'best.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    # Strict improvement over the previous best only.
    if epoch_accuracy > max_accuracy:
        torch.save(save_state, best_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the total ``norm_type``-norm over all existing gradients, as a float.

    Args:
        parameters: a single tensor or an iterable of tensors; entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: order of the norm; ``float('inf')`` gives the largest
            absolute gradient entry.

    Returns:
        ``(sum_i ||g_i||_p ** p) ** (1 / p)`` over the gradients, the max-abs
        value for the infinity norm, or ``0.0`` when no gradient exists.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return 0.0
    if norm_type == float('inf'):
        # The p-norm power formula degenerates for p = inf (x**inf, then **0);
        # the infinity norm of the concatenation is the max of per-tensor maxima.
        return max(g.norm(norm_type).item() for g in grads)
    # Accumulate in Python floats (double precision), one tensor at a time.
    total_norm = 0.0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)
def auto_resume_helper(output_dir, logger):
    """Pick the checkpoint to auto-resume from.

    Considers files in ``output_dir`` whose names start with ``ckpt_`` and end
    with ``pth``, and returns the full path of the one with the newest
    modification time, or ``None`` when there are no such files.
    """
    found = [name for name in os.listdir(output_dir)
             if name.startswith('ckpt_') and name.endswith('pth')]
    logger.info(f"All checkpoints founded in {output_dir}: {found}")
    if not found:
        return None
    newest = max((os.path.join(output_dir, name) for name in found),
                 key=os.path.getmtime)
    logger.info(f"The latest checkpoint founded: {newest}")
    return newest
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged over every rank of the default process group.

    The input is cloned first, so the caller's tensor is not modified.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(dist.get_world_size())
def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Undo mean/std normalization of a float image tensor.

    Computes ``tensor * std + mean`` per channel and clips the result to
    ``[0, 1]``. This is the inverse of :class:`~torchvision.transforms.Normalize`
    (apart from the clipping). PIL images are not supported.

    .. note::
        Works out of place by default: the input tensor is cloned first
        unless ``inplace=True``.

    Args:
        tensor (Tensor): float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
        mean (sequence): per-channel means.
        std (sequence): per-channel standard deviations.
        inplace (bool, optional): modify ``tensor`` directly instead of a copy.

    Returns:
        Tensor: the denormalized image, clipped to ``[0, 1]``.

    Raises:
        TypeError: if ``tensor`` is not a tensor or not floating point.
        ValueError: if ``tensor`` has fewer than 3 dims or any ``std`` is zero.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    out = tensor if inplace else tensor.clone()
    dtype, device = out.dtype, out.device
    mean_t = torch.as_tensor(mean, dtype=dtype, device=device)
    std_t = torch.as_tensor(std, dtype=dtype, device=device)
    if (std_t == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))

    # A 1-D per-channel vector becomes (C, 1, 1) so it broadcasts over H and W.
    mean_t = mean_t.view(-1, 1, 1) if mean_t.ndim == 1 else mean_t
    std_t = std_t.view(-1, 1, 1) if std_t.ndim == 1 else std_t
    return out.mul_(std_t).add_(mean_t).clip_(0.0, 1.0)
|