# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

from models.modules import ConvNextBlock, Decoder, LayerNorm, SimDecoder, UpSampleConvnext
from models.revcol_function import ReverseFunction


class Fusion(nn.Module):
def __init__(self, level, channels, first_col) -> None:
super().__init__()
self.level = level
self.first_col = first_col
self.down = nn.Sequential(
nn.Conv2d(channels[level-1], channels[level], kernel_size=2, stride=2),
LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
) if level in [1, 2, 3] else nn.Identity()
if not first_col:
self.up = UpSampleConvnext(1, channels[level+1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        # c_down: feature one level below in the current column
        #         (for level 0 this is the stem output x)
        # c_up:   feature one level above from the previous column
        #         (None for level 3, which has no top-down input)
        c_down, c_up = args
if self.first_col:
x = self.down(c_down)
return x
if self.level == 3:
x = self.down(c_down)
else:
x = self.up(c_up) + self.down(c_down)
return x
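
# A minimal shape sketch for Fusion (illustrative only; the channel list is
# the revcol_tiny configuration, not a requirement):
#
#   fusion = Fusion(level=1, channels=[64, 128, 256, 512], first_col=False)
#   y = fusion(c_down, c_up)
#
# down strides c_down of shape (B, 64, H, W) to (B, 128, H/2, W/2); up must
# bring c_up of shape (B, 256, H/4, W/4) to the same shape for the sum to
# type-check, so the result is (B, 128, H/2, W/2).
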

class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rates=None) -> None:
        super().__init__()
        if dp_rates is None:
            # the original scalar default (0.0) cannot be indexed below;
            # fall back to no drop-path instead
            dp_rates = [0.0] * sum(layers)
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [
            ConvNextBlock(
                channels[level], expansion * channels[level], channels[level],
                kernel_size=kernel_size, layer_scale_init_value=1e-6,
                drop_path=dp_rates[countlayer + i],
            )
            for i in range(layers[level])
        ]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
x = self.fusion(*args)
x = self.blocks(x)
return x
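
# Each Level is one stage of one column: Fusion merges the incoming features,
# then layers[level] ConvNeXt blocks refine them at constant resolution. The
# drop-path rates are drawn from one schedule spanning the whole column,
# indexed by countlayer + i so each level continues where the previous left off.
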
class SubNet(nn.Module):
def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
super().__init__()
shortcut_scale_init_value = 0.5
self.save_memory = save_memory
self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
requires_grad=True) if shortcut_scale_init_value > 0 else None
self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
requires_grad=True) if shortcut_scale_init_value > 0 else None
self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
requires_grad=True) if shortcut_scale_init_value > 0 else None
self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args
        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(local_funs, alpha, *args)
        return c0, c1, c2, c3

    def forward(self, *args):
        # keep every shortcut scale away from zero so the column update
        # stays invertible (see _clamp_abs below)
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)
if self.save_memory:
return self._forward_reverse(*args)
else:
return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign
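
# Column update (see _forward_nonreverse): with the stem output x feeding
# level 0 from below and no top-down input at level 3,
#
#   c_i' = alpha_i * c_i + Level_i(c_{i-1}', c_{i+1})
#
# where c_{i-1}' is the freshly updated lower level of the same column and
# c_{i+1} is the previous column's value one level up. Because |alpha_i| is
# clamped to >= 1e-3, each update can be inverted:
#
#   c_i = (c_i' - Level_i(c_{i-1}', c_{i+1})) / alpha_i
#
# ReverseFunction exploits this when save_memory=True: a column's inputs are
# recomputed from its outputs during backward instead of being stored.
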
class Classifier(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Sequential(
nn.LayerNorm(in_channels, eps=1e-6), # final norm layer
nn.Linear(in_channels, num_classes),
)

    def forward(self, x):
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
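
# The head above is a plain pooled classifier: global average pool on the
# level-3 feature, LayerNorm, then a single Linear projection to logits.
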
class FullNet(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5,
                 kernel_size=3, num_classes=1000, drop_path=0.0, save_memory=True,
                 inter_supv=True, head_init_scale=None) -> None:
super().__init__()
self.num_subnet = num_subnet
self.inter_supv = inter_supv
self.channels = channels
self.layers = layers
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
)
        # stochastic depth: one linear schedule shared by every column
        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = (i == 0)
            self.add_module(f'subnet{i}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))

        if not inter_supv:
            self.cls = Classifier(in_channels=channels[-1], num_classes=num_classes)
        else:
            self.cls_blocks = nn.ModuleList([Classifier(in_channels=channels[-1], num_classes=num_classes) for _ in range(4)])
            if num_classes <= 1000:
                channels.reverse()  # Decoder expects dims ordered from the deepest level outward
                self.decoder_blocks = nn.ModuleList([Decoder(depth=[1, 1, 1, 1], dim=channels, block_type=ConvNextBlock, kernel_size=3) for _ in range(3)])
            else:
                self.decoder_blocks = nn.ModuleList([SimDecoder(in_channel=channels[-1], encoder_stride=32) for _ in range(3)])

        self.apply(self._init_weights)
        if head_init_scale:
            # only meaningful with inter_supv=False, where the single self.cls head exists
            print(f'Head_init_scale: {head_init_scale}')
            self.cls.classifier[1].weight.data.mul_(head_init_scale)
            self.cls.classifier[1].bias.data.mul_(head_init_scale)

    def forward(self, x):
        if self.inter_supv:
            return self._forward_intermediate_supervision(x)
        else:
            c0, c1, c2, c3 = 0, 0, 0, 0
            x = self.stem(x)
            for i in range(self.num_subnet):
                c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
            return [self.cls(c3)], None

    def _forward_intermediate_supervision(self, x):
        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        interval = self.num_subnet // 4
        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
            # a classification head after every `interval` columns, plus a
            # reconstruction decoder everywhere except after the last column
            if (i + 1) % interval == 0:
                x_cls_out.append(self.cls_blocks[i // interval](c3))
                if i != self.num_subnet - 1:
                    x_img_out.append(self.decoder_blocks[i // interval](c3))
        return x_cls_out, x_img_out

    def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
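
# With inter_supv=True, the columns form four groups of num_subnet // 4; after
# each group the level-3 feature goes through one of the four Classifier heads,
# and (for all but the last group) through one of the three decoders for image
# reconstruction. This indexing assumes num_subnet is a multiple of 4, as in
# every variant below; the default num_subnet=5 would index past cls_blocks.
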
##-------------------------------------- Tiny -----------------------------------------
def revcol_tiny(save_memory, inter_supv=True, drop_path=0.1, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 4
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)


##-------------------------------------- Small -----------------------------------------
def revcol_small(save_memory, inter_supv=True, drop_path=0.3, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)


##-------------------------------------- Base -----------------------------------------
def revcol_base(save_memory, inter_supv=True, drop_path=0.4, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [72, 144, 288, 576]
    layers = [1, 1, 3, 2]
    num_subnet = 16
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)


##-------------------------------------- Large -----------------------------------------
def revcol_large(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [128, 256, 512, 1024]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)


##-------------------------------------- Extra-Large -----------------------------------------
def revcol_xlarge(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [224, 448, 896, 1792]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)
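

##-------------------------------------- Smoke test -----------------------------------------
# Minimal usage sketch, not part of the original training entry point; it
# assumes models.modules and models.revcol_function are importable and uses a
# 224x224 ImageNet-style input.
if __name__ == "__main__":
    model = revcol_tiny(save_memory=False)  # 4 columns, inter_supv=True
    x = torch.randn(2, 3, 224, 224)
    cls_out, img_out = model(x)
    print([t.shape for t in cls_out])  # 4 heads, each torch.Size([2, 1000])
    print(len(img_out))                # 3 intermediate reconstructions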