Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
import torch | |
import torch.distributed as dist | |
from fvcore.nn.distributed import differentiable_all_reduce | |
from torch import nn | |
from torch.nn import functional as F | |
from detectron2.utils import comm, env | |
from .wrappers import BatchNorm2d | |
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): length of each per-channel buffer.
            eps (float): value added to the variance before taking the square root.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Subtracting eps here cancels the eps that forward() adds back, so the
        # default buffers implement an identity transform.
        self.register_buffer("running_var", torch.ones(num_features) - eps)
        self.register_buffer("num_batches_tracked", None)

    def forward(self, x):
        """
        Apply the fixed per-channel affine normalization to ``x``.

        Args:
            x (Tensor): input whose dimension 1 has ``num_features`` entries; the
                gradient path reshapes scale/bias to ``(1, C, 1, 1)`` for broadcasting.

        Returns:
            Tensor: the normalized input.
        """
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load buffers from ``state_dict``, filling in statistics that old checkpoints lack.

        Checkpoints with no version, or a version below 2, may not contain
        "running_mean"/"running_var"; zeros/ones are inserted for them so the
        parent loader does not report them as missing.
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    # The method receives `cls` and recurses through `cls.`, so it has to be a
    # classmethod; without the decorator `FrozenBatchNorm2d.convert_frozen_batchnorm(m)`
    # would bind `m` to `cls` and fail with a missing-argument TypeError.
    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Note:
            The affine parameters are cloned (only when ``module.affine`` is set),
            while the running statistics of the new module reference the same
            data as the source module.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
            res.num_batches_tracked = module.num_batches_tracked
        else:
            # Recurse into children and swap in any child that was replaced.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res

    @classmethod
    def convert_frozenbatchnorm2d_to_batchnorm2d(cls, module: nn.Module) -> nn.Module:
        """
        Convert all FrozenBatchNorm2d to BatchNorm2d

        Args:
            module (torch.nn.Module):

        Returns:
            If module is FrozenBatchNorm2d, returns a new module.
            Otherwise, in-place convert module and return it.

        This is needed for quantization:
            https://fb.workplace.com/groups/1043663463248667/permalink/1296330057982005/
        """
        res = module
        if isinstance(module, FrozenBatchNorm2d):
            res = torch.nn.BatchNorm2d(module.num_features, module.eps)

            # All four tensors are cloned so the new module owns its data.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data.clone().detach()
            res.running_var.data = module.running_var.data.clone().detach()
            res.eps = module.eps
            res.num_batches_tracked = module.num_batches_tracked
        else:
            # Recurse into children and swap in any child that was replaced.
            for name, child in module.named_children():
                new_child = cls.convert_frozenbatchnorm2d_to_batchnorm2d(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
def get_norm(norm, out_channels):
    """
    Build a normalization layer from a name or a factory callable.

    Args:
        norm (str or callable): one of the names "BN", "SyncBN", "FrozenBN", "GN",
            "nnSyncBN", "naiveSyncBN", "naiveSyncBN_N", "LN"; or a callable that
            takes a channel number and returns the normalization layer
            as a nn.Module. ``None`` and the empty string mean "no normalization".
        out_channels (int): channel number handed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a non-empty string that is not a known name.
    """
    if norm is None:
        return None

    # Anything that is not a string is treated as a ready-made factory.
    if not isinstance(norm, str):
        return norm(out_channels)

    if not norm:
        return None

    factories = {
        "BN": BatchNorm2d,
        # Fixed in https://github.com/pytorch/pytorch/pull/36382
        "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
        "FrozenBN": FrozenBatchNorm2d,
        "GN": lambda channels: nn.GroupNorm(32, channels),
        # for debugging:
        "nnSyncBN": nn.SyncBatchNorm,
        "naiveSyncBN": NaiveSyncBatchNorm,
        # expose stats_mode N as an option to caller, required for zero-len inputs
        "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
        "LN": lambda channels: LayerNorm(channels),
    }
    make_layer = factories[norm]
    return make_layer(out_channels)
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to the ``BatchNorm2d`` base class.
            stats_mode (str): either ``""`` or ``"N"``; selects how per-worker
                statistics are combined (see the class docstring).
        """
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        """
        Normalize ``input`` with statistics combined across workers.

        With a world size of 1, or outside training mode, this defers to the
        base class forward. Otherwise the per-channel mean and mean-of-squares
        are all-reduced, the running statistics are updated in place, and the
        normalization is applied as an explicit scale-and-shift.
        """
        # Nothing to synchronize: single worker, or eval mode.
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        half_input = input.dtype == torch.float16
        if half_input:
            # fp16 does not have good enough numerics for the reduction here
            input = input.float()
        # Reduce over every dimension except dim 1, leaving one value per channel.
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            # Pack both statistics into one vector so a single all-reduce suffices.
            vec = torch.cat([mean, meansqr], dim=0)
            # Dividing by the world size gives every worker equal weight
            # (presumably the all-reduce sums across workers -- confirm in fvcore).
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                # The trailing 1 becomes this worker's batch size after the
                # multiplication by B below.
                vec = torch.cat(
                    [
                        mean,
                        meansqr,
                        torch.ones([1], device=mean.device, dtype=mean.dtype),
                    ],
                    dim=0,
                )
            # Weight this worker's statistics by its batch size before reducing.
            vec = differentiable_all_reduce(vec * B)

            # Last entry is the reduced batch-size term; detached so it acts as a constant.
            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            # Splitting 2C + 1 entries in chunks of C yields (C, C, 1); the last chunk is unused.
            mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C)  # avoid div-by-zero

        # Variance as E[x^2] - E[x]^2.
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        # Fold normalization and the affine parameters into one scale and one shift.
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        # In-place update of the running statistics from the detached batch statistics.
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        ret = input * scale + bias
        if half_input:
            # Cast back to the dtype the caller passed in.
            ret = ret.half()
        return ret
class CycleBatchNormList(nn.ModuleList):
    """
    Domain-specific BatchNorm implemented by cycling through several BN layers.

    When one BatchNorm layer serves multiple input domains or input features,
    it might need to keep separate test-time statistics for each of them.
    See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module holds N separate BN layers and advances to the next one
    every time forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
                An ``affine`` entry (default True) is consumed here: the inner
                layers are always built with ``affine=False`` and, when it is
                true, one weight/bias pair shared by all layers is created.
        """
        self._affine = kwargs.pop("affine", True)
        layers = []
        for _ in range(length):
            layers.append(bn_class(**kwargs, affine=False))
        super().__init__(layers)

        if self._affine:
            # shared affine, domain-specific BN
            num_channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        self._pos = 0

    def forward(self, x):
        """Run the current BN layer on ``x``, then advance to the next layer."""
        normalized = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)

        if not self._affine:
            return normalized

        # Broadcast the shared per-channel affine over (N, C, H, W).
        gain = self.weight.reshape(1, -1, 1, 1)
        shift = self.bias.reshape(1, -1, 1, 1)
        return normalized * gain + shift

    def extra_repr(self):
        return "affine={}".format(self._affine)
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that normalizes mean and
    variance point-wise over the channel dimension of inputs shaped
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape: size of the learnable per-channel weight and bias.
            eps (float): value added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` along dimension 1, then apply the per-channel affine."""
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        # Biased variance: plain mean of the squared deviations.
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        gain = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return gain * normalized + shift