"""Model definition for RevCol (Reversible Column Networks)."""

import torch
import torch.nn as nn

from timm.models.layers import trunc_normal_

from models.modules import ConvNextBlock, Decoder, LayerNorm, SimDecoder, UpSampleConvnext
from models.revcol_function import ReverseFunction


class Fusion(nn.Module):
    """Fuses the inputs to a level: a strided-conv downsampling path from the
    finer level above and, for all columns after the first, an upsampling
    path from the coarser level of the previous column."""

    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        # Levels 1-3 downsample the incoming higher-resolution feature map by 2x.
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        # Only non-first columns receive a top-down signal, so only they need
        # the upsampling branch; level 3 has no coarser level to draw from.
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        c_down, c_up = args

        if self.first_col:
            return self.down(c_down)

        if self.level == 3:
            x = self.down(c_down)
        else:
            x = self.up(c_up) + self.down(c_down)
        return x

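
# Shape sketch for Fusion (illustrative, assuming channels=[64, 128, 256, 512],
# a 224x224 input — i.e. a 56x56 level-0 map after the stride-4 stem — and
# that UpSampleConvnext with ratio 1 upsamples by 2x): at level=1, `down` maps
# the level-0 state (B, 64, 56, 56) to (B, 128, 28, 28), while `up` lifts the
# previous column's level-2 state (B, 256, 14, 14) to (B, 128, 28, 28), so the
# two branches can be summed elementwise.
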
class Level(nn.Module):
    """One resolution level of a column: fusion of its inputs followed by a
    stack of ConvNeXt blocks."""

    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
        super().__init__()
        # `dp_rate` must be a sequence with one drop-path rate per block across
        # all levels; `countlayer` is this level's offset into that schedule.
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [
            ConvNextBlock(channels[level], expansion * channels[level], channels[level],
                          kernel_size=kernel_size, layer_scale_init_value=1e-6,
                          drop_path=dp_rate[countlayer + i])
            for i in range(layers[level])
        ]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x = self.blocks(x)
        return x

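
# Drop-path bookkeeping (illustrative): FullNet builds one flat schedule with
# torch.linspace(0, drop_path, sum(layers)) and every column shares it. With
# layers=[2, 2, 4, 2], level 2 has countlayer = sum(layers[:2]) = 4 and four
# blocks, so it consumes entries 4..7 of the schedule.
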
class SubNet(nn.Module):
    """One column of the network: four levels connected by reversible
    shortcuts whose scales (the alphas) are learnable."""

    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args

        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(local_funs, alpha, *args)
        return c0, c1, c2, c3

    def forward(self, *args):
        # The reverse pass divides by alpha to reconstruct intermediate
        # states, so keep every alpha bounded away from zero.
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(min=value)
            data *= sign

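
# Reversible update sketch (illustrative of the RevCol scheme): each column
# computes
#     c_i' = alpha_i * c_i + level_i(c_{i-1}', c_{i+1})
# and, with save_memory=True, ReverseFunction recomputes the inputs during
# the backward pass instead of caching activations:
#     c_i = (c_i' - level_i(c_{i-1}', c_{i+1})) / alpha_i
# The division by alpha is why SubNet.forward clamps every |alpha_i| >= 1e-3.
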
class Classifier(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.LayerNorm(in_channels, eps=1e-6),
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, x):
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class FullNet(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, kernel_size=3,
                 num_classes=1000, drop_path=0.0, save_memory=True, inter_supv=True, head_init_scale=None) -> None:
        super().__init__()
        self.num_subnet = num_subnet
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )

        # One flat stochastic-depth schedule, shared by every column.
        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = (i == 0)
            self.add_module(f'subnet{i}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))

        if not inter_supv:
            self.cls = Classifier(in_channels=channels[-1], num_classes=num_classes)
        else:
            self.cls_blocks = nn.ModuleList([Classifier(in_channels=channels[-1], num_classes=num_classes) for _ in range(4)])
            if num_classes <= 1000:
                # The decoder consumes the pyramid coarse-to-fine, so it gets a
                # reversed copy of the channel list (copying, rather than calling
                # channels.reverse(), avoids mutating the caller's list and the
                # shared default argument in place).
                self.decoder_blocks = nn.ModuleList([
                    Decoder(depth=[1, 1, 1, 1], dim=list(reversed(channels)), block_type=ConvNextBlock, kernel_size=3)
                    for _ in range(3)])
            else:
                self.decoder_blocks = nn.ModuleList([SimDecoder(in_channel=channels[-1], encoder_stride=32) for _ in range(3)])

        self.apply(self._init_weights)

        if head_init_scale:
            # Only applicable to the single-head variant (inter_supv=False),
            # which is the only configuration that defines `self.cls`.
            print(f'Head_init_scale: {head_init_scale}')
            self.cls.classifier[1].weight.data.mul_(head_init_scale)
            self.cls.classifier[1].bias.data.mul_(head_init_scale)

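    # Output convention: with inter_supv=False, forward returns ([logits], None);
    # with intermediate supervision it returns (list of logits, list of decoder
    # reconstructions), one classification head per quarter of the columns.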
    def forward(self, x):
        if self.inter_supv:
            return self._forward_intermediate_supervision(x)
        else:
            c0, c1, c2, c3 = 0, 0, 0, 0
            x = self.stem(x)
            for i in range(self.num_subnet):
                c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
            return [self.cls(c3)], None

    def _forward_intermediate_supervision(self, x):
        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        # Emit supervision after every quarter of the columns; this assumes
        # num_subnet is a positive multiple of 4, as in the factories below.
        interval = self.num_subnet // 4

        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
            if (i + 1) % interval == 0:
                x_cls_out.append(self.cls_blocks[i // interval](c3))
                if i != self.num_subnet - 1:
                    x_img_out.append(self.decoder_blocks[i // interval](c3))

        return x_cls_out, x_img_out

    def _init_weights(self, module):
        # Conv and linear layers share the same truncated-normal init; guard
        # the bias in case a layer is built with bias=False.
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)


def revcol_tiny(save_memory, inter_supv=True, drop_path=0.1, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 4
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)


def revcol_small(save_memory, inter_supv=True, drop_path=0.3, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)


def revcol_base(save_memory, inter_supv=True, drop_path=0.4, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [72, 144, 288, 576]
    layers = [1, 1, 3, 2]
    num_subnet = 16
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)


def revcol_large(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [128, 256, 512, 1024]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)


def revcol_xlarge(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [224, 448, 896, 1792]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path,
                   save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale,
                   kernel_size=kernel_size)
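

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original module).
# Assumes `models.modules` and `models.revcol_function` are importable and
# simply checks that a small model runs end to end on random inputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = revcol_tiny(save_memory=False)   # 4 columns, inter_supv=True
    dummy = torch.randn(2, 3, 224, 224)
    cls_outs, img_outs = model(dummy)
    # One classification head per quarter of the columns (4 here) and one
    # decoder output for every supervised column except the last (3 here).
    print([tuple(o.shape) for o in cls_outs])   # expect 4 x (2, 1000)
    print(len(img_outs))                        # expect 3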