File size: 6,710 Bytes
b9425fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
import imp
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
class UpSampleConvnext(nn.Module):
    """Project channels with Linear + LayerNorm, then upsample spatially by ``2 ** ratio``.

    Args:
        ratio: Exponent of the scale factor; ``nn.Upsample`` uses
            ``scale_factor=2 ** ratio`` with nearest-neighbour interpolation.
        inchannel: Number of channels of the input feature map.
        outchannel: Number of channels produced by the projection.

    ``forward`` expects an (N, inchannel, H, W) tensor. It is permuted to channels-last
    so the Linear acts per pixel. It returns (N, outchannel, H', W') with the spatial
    size scaled by ``2 ** ratio``.
    """

    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        # Operates on channels-last tensors, hence LayerNorm(data_format="channels_last").
        self.channel_reschedule = nn.Sequential(
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"),
        )
        self.upsample = nn.Upsample(scale_factor=2 ** ratio, mode='nearest')

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.channel_reschedule(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return self.upsample(x)
class LayerNorm(nn.Module):
    r"""LayerNorm over the channel dimension for channels_first (default) or channels_last inputs.

    The data format gives the ordering of the dimensions in the inputs.
    channels_last corresponds to inputs with shape (batch_size, height, width, channels).
    channels_first corresponds to inputs with shape (batch_size, channels, height, width).

    Args:
        normalized_shape (int): Number of channels to normalize over.
        eps (float): Value added to the variance for numerical stability. Default: 1e-6.
        data_format (str): "channels_first" (default) or "channels_last".
        elementwise_affine (bool): If True, learn a per-channel scale (initialized to 1)
            and bias (initialized to 0). Default: True.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine=True):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data_format: {data_format!r}")
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # Register as None so the attributes always exist: the channels_last path
            # passes them straight to F.layer_norm, which accepts None for both.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: F.layer_norm only normalizes trailing dims, so normalize dim 1 by hand.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance, as in F.layer_norm
        x = (x - u) / torch.sqrt(s + self.eps)
        if self.elementwise_affine:
            # (C,) -> (C, 1, 1) so the affine parameters broadcast over H and W.
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
class ConvNextBlock(nn.Module):
    r"""ConvNeXt block with a residual connection.

    There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        in_channel (int): Number of input channels.
        hidden_dim (int): Width of the hidden layer between the two pointwise Linears.
        out_channel (int): Number of output channels. The output is added to the input,
            so forward only works when this equals ``in_channel``.
        kernel_size (int): Depthwise kernel size. Padding is ``(kernel_size - 1) // 2``,
            which keeps H and W unchanged only for odd sizes. Default: 3.
        layer_scale_init_value (float): Init value for the per-channel Layer Scale
            ``gamma``; a value <= 0 disables it. Default: 1e-6.
        drop_path (float): Stochastic depth rate; 0 means identity. Default: 0.0.
    """

    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size,
                                padding=(kernel_size - 1) // 2, groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x  # per-channel scale, broadcast over the trailing C axis
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(x)
class Decoder(nn.Module):
    """Multi-stage decoder that upsamples a feature map and projects it to 3 channels.

    For each stage i in 1 .. len(dim) - 1, the decoder does three things:
    - projects channels dim[i-1] -> dim[i] with a 1x1 conv + LayerNorm + GELU;
    - upsamples 2x bilinearly;
    - applies ``depth[i]`` blocks built as ``block_type(dim[i], dim[i], dim[i], kernel_size)``.

    A final 1x1 conv to ``4 ** 2 * 3`` channels followed by ``PixelShuffle(4)`` yields
    a 3-channel output at 4x the last stage's resolution.

    Args:
        depth: Blocks per stage. Only ``depth[1:]`` is used. Default: [2, 2, 2, 2].
        dim: Channel widths; ``dim[0]`` is the channel count of the input to ``forward``.
            Default: [112, 72, 40, 24].
        block_type: Callable building one block, e.g. ``ConvNextBlock``. It must be given
            whenever some ``depth[i] > 0`` (i >= 1); the default None cannot be called.
        kernel_size: Forwarded to ``block_type``. Default: 3.
    """

    def __init__(self, depth=None, dim=None, block_type=None, kernel_size=3) -> None:
        super().__init__()
        # None sentinels instead of list-literal defaults: a literal default would be a
        # single list object shared (and stored on self) by every instance.
        if depth is None:
            depth = [2, 2, 2, 2]
        if dim is None:
            dim = [112, 72, 40, 24]
        self.depth = depth
        self.dim = dim
        self.block_type = block_type
        self._build_decode_layer(dim, depth, kernel_size)
        self.projback = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=4 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(4),
        )

    def _build_decode_layer(self, dim, depth, kernel_size):
        """Create per-stage projection, upsampling and block stacks (one entry per stage i >= 1)."""
        normal_layers = nn.ModuleList()
        upsample_layers = nn.ModuleList()
        proj_layers = nn.ModuleList()
        # Applied after Conv2d, i.e. on (N, C, H, W): LayerNorm's default channels_first format.
        norm_layer = LayerNorm
        for i in range(1, len(dim)):
            blocks = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*blocks))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i - 1], dim[i], 1, 1),
                norm_layer(dim[i]),
                nn.GELU(),
            ))
        self.normal_layers = normal_layers
        self.upsample_layers = upsample_layers
        self.proj_layers = proj_layers

    def _forward_stage(self, stage, x):
        """Run one stage: channel projection, 2x upsampling, then the stage's blocks."""
        x = self.proj_layers[stage](x)
        x = self.upsample_layers[stage](x)
        return self.normal_layers[stage](x)

    def forward(self, c3):
        """Decode ``c3`` (dim[0] channels) through every stage, then project to 3 channels.

        Each stage doubles H and W and the final PixelShuffle multiplies them by 4.
        """
        x = c3
        # Iterate over all built stages (len(dim) - 1) rather than a fixed three,
        # so non-default ``dim`` lengths decode correctly.
        for stage in range(len(self.normal_layers)):
            x = self._forward_stage(stage, x)
        return self.projback(x)
class SimDecoder(nn.Module):
    """Single-step decoder: normalize, project with a 1x1 conv, then pixel-shuffle.

    The LayerNorm is built with its default data format (channels_first), so the input
    is expected as (N, in_channel, H, W). The 1x1 conv emits ``encoder_stride ** 2 * 3``
    channels. ``PixelShuffle(encoder_stride)`` rearranges them into 3 channels with H and W
    each multiplied by ``encoder_stride``.
    """

    def __init__(self, in_channel, encoder_stride) -> None:
        super().__init__()
        pixel_channels = encoder_stride ** 2 * 3
        stages = [
            LayerNorm(in_channel),
            nn.Conv2d(in_channels=in_channel, out_channels=pixel_channels, kernel_size=1),
            nn.PixelShuffle(encoder_stride),
        ]
        self.projback = nn.Sequential(*stages)

    def forward(self, c3):
        decoded = self.projback(c3)
        return decoded