# MambaVision-B-1K / modeling_mambavision.py
# Last change: "Update modeling_mambavision.py" by ahatamiz (commit 34dd0a1).
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
from timm.models._builder import _update_default_kwargs as update_args
except:
from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat
from transformers import PreTrainedModel
from .configuration_mambavision import MambaVisionConfig
def _cfg(url='', **kwargs):
return {'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'crop_pct': 0.875,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
**kwargs
}
# Pretrained configurations for every released MambaVision variant; all share
# a 224x224 input and center cropping, differing only in URL and crop_pct.
default_cfgs = {
    model_name: _cfg(url=checkpoint_url,
                     crop_pct=crop_fraction,
                     input_size=(3, 224, 224),
                     crop_mode='center')
    for model_name, checkpoint_url, crop_fraction in (
        ('mamba_vision_T',
         'https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
         1.0),
        ('mamba_vision_T2',
         'https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
         0.98),
        ('mamba_vision_S',
         'https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
         0.93),
        ('mamba_vision_B',
         'https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
         1.0),
        ('mamba_vision_L',
         'https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
         1.0),
        ('mamba_vision_L2',
         'https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
         1.0),
    )
}
def window_partition(x, window_size):
    """Split a channels-first feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, C, H, W); H and W must be multiples of
            ``window_size`` (otherwise the ``view`` below fails).
        window_size: side length of each square window.

    Returns:
        Tensor of shape (B * num_windows, window_size * window_size, C) with
        num_windows = (H // window_size) * (W // window_size). Windows are
        ordered row-major over the window grid, and tokens inside a window are
        ordered row-major as well.
    """
    batch, channels, height, width = x.shape
    ws = window_size
    tiles = x.view(batch, channels, height // ws, ws, width // ws, ws)
    # (B, C, nH, ws, nW, ws) -> (B, nH, nW, ws, ws, C): grid first, channels last.
    tiles = tiles.permute(0, 2, 4, 3, 5, 1)
    return tiles.reshape(-1, ws * ws, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: tensor of shape (B * num_windows, window_size * window_size, C),
            laid out as returned by ``window_partition``.
        window_size: side length of each square window.
        H: height of the (padded) feature map; a multiple of ``window_size``.
        W: width of the (padded) feature map; a multiple of ``window_size``.

    Returns:
        x: tensor of shape (B, C, H, W).
    """
    n_h, n_w = H // window_size, W // window_size
    # Exact integer arithmetic for the batch size (no float round-trip).
    B = windows.shape[0] // (n_h * n_w)
    C = windows.shape[2]
    x = windows.reshape(B, n_h, n_w, window_size, window_size, C)
    # (B, nH, nW, ws, ws, C) -> (B, C, nH, ws, nW, ws) -> (B, C, H, W)
    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
    return x
def _load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
load = None
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
def _load_checkpoint(model,
                     filename,
                     map_location='cpu',
                     strict=False,
                     logger=None):
    """Load a checkpoint file into ``model``.

    The weights are taken from ``checkpoint['state_dict']``, else
    ``checkpoint['model']``, else the checkpoint itself. A leading ``module.``
    prefix is removed, and if the (sorted) keys start with ``encoder`` only the
    ``encoder.*`` entries are kept, with that prefix removed.

    Args:
        model (Module): module that receives the weights.
        filename (str): anything accepted by :func:`torch.load` (e.g. a local path).
        map_location (str): forwarded to :func:`torch.load`.
        strict (bool): raise on any key mismatch instead of only reporting it.
        logger (:mod:`logging.Logger` or None): receives the mismatch report.

    Returns:
        dict or OrderedDict: the raw loaded checkpoint (not the extracted weights).

    Raises:
        RuntimeError: if the loaded object is not a dict, or (when ``strict``)
            if the keys do not match the model.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects unless
    # ``weights_only=True`` is passed; only load checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location=map_location)
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    module_prefix = 'module.'
    encoder_prefix = 'encoder.'
    keys = list(state_dict.keys())
    # Only the first key is inspected, so strip per key: keys that do not carry
    # the prefix are left intact instead of losing their first 7 characters.
    # The emptiness guard avoids an IndexError on an empty state dict.
    if keys and keys[0].startswith(module_prefix):
        state_dict = {(k[len(module_prefix):] if k.startswith(module_prefix) else k): v
                      for k, v in state_dict.items()}
    # Keep only encoder weights and remove just the leading 'encoder.'
    # (str.replace would also delete inner occurrences of 'encoder.').
    if state_dict and min(state_dict.keys()).startswith('encoder'):
        state_dict = {k[len(encoder_prefix):]: v for k, v in state_dict.items()
                      if k.startswith(encoder_prefix)}
    _load_state_dict(model, state_dict, strict, logger)
    return checkpoint
class Downsample(nn.Module):
    """Strided 3x3 convolution that halves the spatial size (rounding up).

    The channel count is doubled unless ``keep_dim`` is set.
    """

    def __init__(self,
                 dim,
                 keep_dim=False,
                 ):
        """
        Args:
            dim: number of input channels.
            keep_dim: if True the output has ``dim`` channels, otherwise ``2 * dim``.
        """
        super().__init__()
        out_channels = dim if keep_dim else dim * 2
        # Wrapped in nn.Sequential so parameter names stay 'reduction.0.*'
        # (state-dict compatibility).
        self.reduction = nn.Sequential(
            nn.Conv2d(dim, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        )

    def forward(self, x):
        return self.reduction(x)
class PatchEmbed(nn.Module):
    """Convolutional stem: two stride-2 Conv-BN-ReLU stages (4x downsampling).

    Note: this local class shadows the ``PatchEmbed`` imported from timm.
    """

    def __init__(self, in_chans=3, in_dim=64, dim=96):
        """
        Args:
            in_chans: number of input image channels.
            in_dim: channel count after the first convolution.
            dim: output feature dimension.
        """
        super().__init__()
        # Parameter-free placeholder applied before the stem.
        self.proj = nn.Identity()
        stem_layers = [
            nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(in_dim, eps=1e-4),
            nn.ReLU(),
            nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(dim, eps=1e-4),
            nn.ReLU(),
        ]
        self.conv_down = nn.Sequential(*stem_layers)

    def forward(self, x):
        return self.conv_down(self.proj(x))
class ConvBlock(nn.Module):
    """Residual block of two Conv-BN layers with a tanh-GELU in between.

    Computes ``x + drop_path(g * BN2(conv2(GELU(BN1(conv1(x))))))``, where the
    per-channel scale ``g`` is only used when ``layer_scale`` is enabled.
    Channel count and (for odd kernels) spatial size are preserved.
    """

    def __init__(self, dim,
                 drop_path=0.,
                 layer_scale=None,
                 kernel_size=3):
        """
        Args:
            dim: number of input/output channels.
            drop_path: drop-path rate for the residual branch (0 disables it).
            layer_scale: if an int or float, initial value of a learnable
                per-channel scale ``g`` on the residual branch; otherwise disabled.
            kernel_size: convolution kernel size; use an odd value so H and W
                are preserved for the residual addition.
        """
        super().__init__()
        # "Same" padding for odd kernels. A fixed padding=1 only preserved H/W
        # for kernel_size=3; other sizes broke the residual addition.
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding)
        self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
        self.act1 = nn.GELU(approximate='tanh')
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding)
        self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
        # self.layer_scale ends up as a bool flag telling forward() whether g exists.
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.g = nn.Parameter(layer_scale * torch.ones(dim))
            self.layer_scale = True
        else:
            self.layer_scale = False
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.layer_scale:
            # Broadcast the per-channel scale over (B, C, H, W).
            x = x * self.g.view(1, -1, 1, 1)
        x = shortcut + self.drop_path(x)
        return x
class MambaVisionMixer(nn.Module):
    """Mamba-style token mixer used by ``Block`` for its non-attention blocks.

    ``in_proj`` maps ``d_model`` -> ``d_inner`` channels, which are split into two
    halves ``x`` and ``z``. Each half goes through its own depthwise 1-D
    convolution followed by SiLU. Only ``x`` is fed to ``selective_scan_fn``.
    The scan output and ``z`` are concatenated along channels (not multiplied)
    and projected back to ``d_model`` by ``out_proj``.

    NOTE(review): the parameter-creation order below determines which random
    numbers each tensor receives and the parameter names expected by existing
    checkpoints; keep it unchanged.
    """
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """
        Args:
            d_model: input/output feature dimension.
            d_state: state size; width of each row of ``A`` and of the B and C
                projections.
            d_conv: kernel size of the depthwise 1-D convolutions.
            expand: ``d_inner = int(expand * d_model)``; each branch gets
                ``d_inner // 2`` channels (d_inner is assumed even — TODO confirm).
            dt_rank: rank of the dt projection; "auto" means ``ceil(d_model / 16)``.
            dt_min, dt_max: range of the log-uniformly sampled initial step sizes.
            dt_init: "constant" or "random" init of ``dt_proj.weight``; any other
                value raises ``NotImplementedError``.
            dt_scale: multiplier on the ``dt_proj.weight`` init range.
            dt_init_floor: lower clamp for the initial step sizes.
            conv_bias: see the NOTE on ``conv1d_x`` below.
            bias: whether ``in_proj`` and ``out_proj`` carry biases.
            use_fast_path, layer_idx: stored only; not read by ``forward``.
            device, dtype: factory kwargs for parameter creation.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        # Produces both branches (x and z) in one projection.
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
        # Per-token projection of the x branch to (dt, B, C).
        self.x_proj = nn.Linear(
            self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Initial step sizes: log-uniform in [dt_min, dt_max], clamped at dt_init_floor.
        dt = torch.exp(
            torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: softplus(inv_dt) == dt, so with delta_softplus=True
        # in forward the bias reproduces the sampled dt at initialization.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Flag presumably read by external (re-)initialization code; not read in this file.
        self.dt_proj.bias._no_reinit = True
        # Every row of A is [1, 2, ..., d_state]; stored as log and negated in forward.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner//2,
        ).contiguous()
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        # _no_weight_decay flags are presumably consumed by optimizer-setup code elsewhere.
        self.A_log._no_weight_decay = True
        # Skip-connection weights passed to the scan as D.
        self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        # NOTE(review): ``conv_bias//2`` is 0 (falsy) for conv_bias=True, so both
        # convolutions are created WITHOUT a bias and forward passes bias=None.
        # Existing checkpoints presumably rely on this parameter set — confirm
        # before changing.
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )
    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        _, seqlen, _ = hidden_states.shape
        xz = self.in_proj(hidden_states)
        # (B, L, d_inner) -> (B, d_inner, L): channels-first for conv1d and the scan.
        xz = rearrange(xz, "b l d -> b d l")
        x, z = xz.chunk(2, dim=1)
        # Negative state matrix recovered from its log parameterization.
        A = -torch.exp(self.A_log.float())
        # Depthwise convs (groups == channels); padding='same' keeps length L.
        x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
        z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # NOTE(review): self.dt_proj(dt) already adds dt_proj.bias, and the same
        # bias is passed again as delta_bias below, so it appears to be applied
        # twice. Pretrained weights were presumably trained this way — do not
        # change without re-validating accuracy.
        dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        # z=None: no gating inside the scan; z is concatenated afterwards instead.
        y = selective_scan_fn(x,
                              dt,
                              A,
                              B,
                              C,
                              self.D.float(),
                              z=None,
                              delta_bias=self.dt_proj.bias.float(),
                              delta_softplus=True,
                              return_last_state=None)
        # (B, d_inner//2, L) + (B, d_inner//2, L) -> (B, d_inner, L)
        y = torch.cat([y, z], dim=1)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence of shape (B, N, C)."""

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim: token feature dimension; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            qkv_bias: whether the fused q/k/v projection has a bias.
            qk_norm: if True, apply ``norm_layer`` per head to q and k.
            attn_drop: dropout probability on the attention weights (training only).
            proj_drop: dropout probability after the output projection.
            norm_layer: normalization used when ``qk_norm`` is set.
        """
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Use torch's fused scaled_dot_product_attention; set False for the
        # explicit softmax(q k^T * scale) v path.
        self.fused_attn = True
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if self.fused_attn:
            # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
            # independent of module.train()/eval(), so disable it outside training.
            # Its default scaling is head_dim ** -0.5, i.e. self.scale.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm hybrid block: a token mixer followed by an MLP, both residual.

    The mixer is ``Attention`` when ``counter`` is listed in
    ``transformer_blocks``, otherwise a ``MambaVisionMixer``. Each sub-layer
    computes ``x + drop_path(g * f(norm(x)))``, where ``g`` is a learnable
    per-channel scale when ``layer_scale`` is an int/float and 1 otherwise.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 counter,
                 transformer_blocks,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 Mlp_block=Mlp,
                 layer_scale=None,
                 ):
        """
        Args:
            dim: token feature dimension.
            num_heads: attention heads (used only by attention blocks).
            counter: index of this block within its stage.
            transformer_blocks: indices of the blocks that use attention.
            mlp_ratio: MLP hidden size as a multiple of ``dim``.
            qkv_bias: bias on the attention q/k/v projection.
            qk_scale: forwarded to ``Attention`` as ``qk_norm`` (per-head q/k
                normalization), despite its name.
            drop: dropout for the attention output projection and the MLP.
            attn_drop: attention-weight dropout.
            drop_path: drop-path rate for both residual branches.
            act_layer: MLP activation.
            norm_layer: normalization for both pre-norms (and q/k norm).
            Mlp_block: MLP class to instantiate.
            layer_scale: initial value of the learnable branch scales, or None.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        uses_attention = counter in transformer_blocks
        if uses_attention:
            self.mixer = Attention(dim,
                                   num_heads=num_heads,
                                   qkv_bias=qkv_bias,
                                   qk_norm=qk_scale,
                                   attn_drop=attn_drop,
                                   proj_drop=drop,
                                   norm_layer=norm_layer)
        else:
            self.mixer = MambaVisionMixer(d_model=dim, d_state=8, d_conv=3, expand=1)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        scaled = layer_scale is not None and type(layer_scale) in [int, float]
        if scaled:
            self.g_1 = nn.Parameter(layer_scale * torch.ones(dim))
            self.g_2 = nn.Parameter(layer_scale * torch.ones(dim))
        else:
            self.g_1 = 1
            self.g_2 = 1

    def forward(self, x):
        mixed = self.mixer(self.norm1(x))
        x = x + self.drop_path(self.g_1 * mixed)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(self.g_2 * transformed)
class MambaVisionLayer(nn.Module):
    """One MambaVision stage: a stack of blocks plus optional downsampling.

    Conv stages run ``ConvBlock``s directly on the (B, C, H, W) map. Other
    stages pad the map on the right/bottom to a multiple of ``window_size``,
    split it into windows, run ``Block``s on each window's tokens, then stitch
    the windows back and crop the padding.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 transformer_blocks=None,
                 ):
        """
        Args:
            dim: feature size dimension.
            depth: number of blocks in this stage.
            num_heads: number of attention heads.
            window_size: window side length for non-conv stages.
            conv: if True, build ``ConvBlock``s; otherwise hybrid ``Block``s.
            downsample: if True, append a ``Downsample`` after the blocks.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: forwarded to ``Block`` (enables q/k normalization).
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop-path rate, either a scalar or a per-block list.
            layer_scale: layer scaling coefficient for ``Block``s.
            layer_scale_conv: layer scaling coefficient for ``ConvBlock``s.
            transformer_blocks: indices of blocks that use attention instead of
                the Mamba mixer; ``None`` (default) means none of them.
        """
        super().__init__()
        # None instead of a shared mutable default list.
        if transformer_blocks is None:
            transformer_blocks = []
        self.conv = conv

        def _block_drop_path(i):
            # Per-block rate when a list is given, the same scalar otherwise.
            return drop_path[i] if isinstance(drop_path, list) else drop_path

        if conv:
            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
                                                   drop_path=_block_drop_path(i),
                                                   layer_scale=layer_scale_conv)
                                         for i in range(depth)])
            self.transformer_block = False
        else:
            self.blocks = nn.ModuleList([Block(dim=dim,
                                               counter=i,
                                               transformer_blocks=transformer_blocks,
                                               num_heads=num_heads,
                                               mlp_ratio=mlp_ratio,
                                               qkv_bias=qkv_bias,
                                               qk_scale=qk_scale,
                                               drop=drop,
                                               attn_drop=attn_drop,
                                               drop_path=_block_drop_path(i),
                                               layer_scale=layer_scale)
                                         for i in range(depth)])
            self.transformer_block = True
        self.downsample = None if not downsample else Downsample(dim=dim)
        # Kept for attribute compatibility; not read anywhere in this file.
        self.do_gt = False
        self.window_size = window_size

    def forward(self, x):
        """Run the stage on a (B, C, H, W) feature map.

        Returns:
            ``(out, features)``: ``features`` is the stage output before
            downsampling; ``out`` is its downsampled version, or the same
            tensor when the stage has no downsampler.
        """
        _, _, H, W = x.shape
        if self.transformer_block:
            # Right/bottom zero padding up to a multiple of window_size.
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            if pad_r > 0 or pad_b > 0:
                x = torch.nn.functional.pad(x, (0, pad_r, 0, pad_b))
                _, _, Hp, Wp = x.shape
            else:
                Hp, Wp = H, W
            # (B, C, Hp, Wp) -> (B * num_windows, window_size**2, C) token sequences.
            x = window_partition(x, self.window_size)
        for blk in self.blocks:
            x = blk(x)
        if self.transformer_block:
            x = window_reverse(x, self.window_size, Hp, Wp)
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        if self.downsample is None:
            return x, x
        return self.downsample(x), x
class MambaVision(nn.Module):
    """MambaVision backbone with a linear classification head.

    A convolutional stem (``PatchEmbed``) is followed by ``len(depths)``
    stages. The first two stages are convolutional; later stages use windowed
    ``Block``s whose second half are attention blocks. Every stage except the
    last downsamples (doubling channels), so the final feature width is
    ``dim * 2 ** (len(depths) - 1)``.
    """

    def __init__(self,
                 dim,
                 in_dim,
                 depths,
                 window_size,
                 mlp_ratio,
                 num_heads,
                 drop_path_rate=0.2,
                 in_chans=3,
                 num_classes=1000,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 **kwargs):
        """
        Args:
            dim: feature size dimension of the first stage.
            in_dim: intermediate channel count of the stem.
            depths: number of blocks in each stage.
            window_size: window size in each stage.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            drop_path_rate: maximum drop-path rate (linearly increased over blocks).
            in_chans: number of input channels.
            num_classes: number of classes; 0 or less replaces the head by Identity.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: forwarded to the blocks (enables q/k normalization).
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            layer_scale: layer scaling coefficient.
            layer_scale_conv: conv layer scaling coefficient.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        num_stages = len(depths)
        num_features = int(dim * 2 ** (num_stages - 1))
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
        # Drop-path rates grow linearly from 0 across all blocks of all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        for i in range(num_stages):
            depth = depths[i]
            level = MambaVisionLayer(dim=int(dim * 2 ** i),
                                     depth=depth,
                                     num_heads=num_heads[i],
                                     window_size=window_size[i],
                                     mlp_ratio=mlp_ratio,
                                     qkv_bias=qkv_bias,
                                     qk_scale=qk_scale,
                                     conv=i < 2,
                                     drop=drop_rate,
                                     attn_drop=attn_drop_rate,
                                     drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                                     # Downsample after every stage but the last so the
                                     # final width equals num_features for any number of
                                     # stages (a fixed ``i < 3`` only held for 4 stages).
                                     downsample=i < num_stages - 1,
                                     layer_scale=layer_scale,
                                     layer_scale_conv=layer_scale_conv,
                                     # Attention in the second half of the stage; for odd
                                     # depths the middle block stays a Mamba-mixer block.
                                     transformer_blocks=list(range((depth + 1) // 2, depth)),
                                     )
            self.levels.append(level)
        self.norm = nn.BatchNorm2d(num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear (truncated normal, zero bias) and norm layers (1 / 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, LayerNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # NOTE(review): no parameter in this model contains 'rpb'; kept for API
        # compatibility with training scripts that call this hook.
        return {'rpb'}

    def forward_features(self, x):
        """Return ``(pooled, outs)``: BN + global-average-pooled final features
        of shape (B, num_features), and the per-stage pre-downsampling maps."""
        x = self.patch_embed(x)
        outs = []
        for level in self.levels:
            x, xo = level(x)
            outs.append(xo)
        x = self.norm(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x, outs

    def forward(self, x):
        """Return classification logits (or pooled features if the head is Identity)."""
        x, _ = self.forward_features(x)
        x = self.head(x)
        return x

    def _load_state_dict(self,
                         pretrained,
                         strict: bool = False):
        """Load weights from the checkpoint file ``pretrained`` into this model."""
        _load_checkpoint(self,
                         pretrained,
                         strict=strict)
class MambaVisionModel(PreTrainedModel):
    """Hugging Face wrapper whose forward returns MambaVision backbone features.

    The wrapped ``MambaVision`` still owns a classification head, but
    ``forward`` bypasses it.
    """
    config_class = MambaVisionConfig

    def __init__(self, config):
        super().__init__(config)
        backbone_kwargs = dict(
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            dim=config.dim,
            in_dim=config.in_dim,
            mlp_ratio=config.mlp_ratio,
            layer_scale=config.layer_scale,
            layer_scale_conv=config.layer_scale_conv,
        )
        self.model = MambaVision(**backbone_kwargs)

    def forward(self, tensor):
        """Return ``(pooled_features, per_stage_outputs)`` for the input images."""
        features = self.model.forward_features(tensor)
        return features
class MambaVisionModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper returning MambaVision classification logits (and loss)."""
    config_class = MambaVisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MambaVision(
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            dim=config.dim,
            in_dim=config.in_dim,
            mlp_ratio=config.mlp_ratio,
            layer_scale=config.layer_scale,
            layer_scale_conv=config.layer_scale_conv
        )

    def forward(self, tensor, labels=None):
        """Classify a batch of images.

        Args:
            tensor: input images passed to ``MambaVision.forward``.
            labels: optional class targets for ``cross_entropy``.

        Returns:
            ``{"logits": logits}``, plus ``"loss"`` (mean cross-entropy) when
            ``labels`` is given.
        """
        logits = self.model(tensor)
        if labels is not None:
            # The functional loss lives in torch.nn.functional; torch.nn has no
            # ``cross_entropy`` attribute (the old call raised AttributeError).
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}