# TADBot/FER/models/vit_model.py
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from collections import OrderedDict
from functools import partial
import math

import torch
import torch.hub
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import DropPath, to_2tuple, trunc_normal_
from timm.models import register_model
from timm.models.vision_transformer import _cfg, Mlp, Block

# NOTE: DropPath, Mlp and Block are redefined later in this file; the local
# definitions take precedence over the timm imports above.
# from .ir50 import Backbone
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
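# Illustrative usage sketch, not part of the original model: drop_path zeroes out
# entire samples with probability drop_prob and rescales the survivors by
# 1 / keep_prob so the expected activation stays unchanged. The tensor shape below
# is an arbitrary assumption.
def _demo_drop_path():
    x = torch.ones(4, 3, 8, 8)  # batch of 4 samples
    out = drop_path(x, drop_prob=0.5, training=True)
    # each sample is either all zeros or scaled to 2.0 (= 1 / keep_prob)
    print(out.reshape(4, -1).amax(dim=1))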
class BasicBlock(nn.Module):
__constants__ = ["downsample"]
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
norm_layer = nn.BatchNorm2d
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(
self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=1)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# print(x.shape)
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
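# Illustrative usage sketch, not part of the original model: with a 14x14 feature
# map of 256 channels (an assumed backbone output), PatchEmbed projects every
# spatial location to a 768-d token, producing a [B, 196, 768] sequence.
def _demo_patch_embed():
    embed = PatchEmbed(img_size=14, patch_size=16, in_c=256, embed_dim=768)
    feats = torch.randn(2, 256, 14, 14)
    tokens = embed(feats)
    print(tokens.shape)  # torch.Size([2, 196, 768])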
class Attention(nn.Module):
def __init__(
self,
dim,
        in_chans,  # number of input image tokens; a class token is appended internally
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_ratio=0.0,
proj_drop_ratio=0.0,
):
super(Attention, self).__init__()
        self.num_heads = num_heads
        self.img_chanel = in_chans + 1
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop_ratio)
    def forward(self, x):
        # keep only the class token + image tokens
        x_img = x[:, : self.img_chanel, :]
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x_img.shape
        # qkv():    -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape:  -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute:  -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = (
            self.qkv(x_img)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @:         -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # @:         -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape:   -> [batch_size, num_patches + 1, total_embed_dim]
        x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_img = self.proj(x_img)
        x_img = self.proj_drop(x_img)
        return x_img
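# Illustrative usage sketch, not part of the original model: Attention acts on the
# first in_chans + 1 tokens (class token + image tokens) and returns a tensor of the
# same shape. The batch size and token count below are assumptions matching the
# defaults used elsewhere in this file.
def _demo_attention():
    attn = Attention(dim=768, in_chans=147, num_heads=8)
    x = torch.randn(2, 148, 768)  # [B, in_chans + 1, dim]
    print(attn(x).shape)          # torch.Size([2, 148, 768])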
class AttentionBlock(nn.Module):
__constants__ = ["downsample"]
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(AttentionBlock, self).__init__()
norm_layer = nn.BatchNorm2d
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
# self.cbam = CBAM(planes, 16)
self.inplanes = inplanes
self.eca_block = eca_block()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
        out = self.eca_block(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
in_chans,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_ratio=0.0,
attn_drop_ratio=0.0,
drop_path_ratio=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super(Block, self).__init__()
self.norm1 = norm_layer(dim)
self.img_chanel = in_chans + 1
self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
self.attn = Attention(
dim,
in_chans=in_chans,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_ratio=attn_drop_ratio,
proj_drop_ratio=drop_ratio,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = (
DropPath(drop_path_ratio) if drop_path_ratio > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop_ratio,
)
    def forward(self, x):
        # attention over the class-token + image-token sequence, then the MLP,
        # each with a residual connection and stochastic depth
        x_img = x + self.drop_path(self.attn(self.norm1(x)))
        x = x_img + self.drop_path(self.mlp(self.norm2(x_img)))
        return x
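# Illustrative usage sketch, not part of the original model: a Block applies
# pre-norm attention followed by a pre-norm MLP, each with a residual connection,
# and preserves the [B, in_chans + 1, dim] shape. Sizes below are assumptions.
def _demo_block():
    blk = Block(dim=768, in_chans=147, num_heads=8)
    x = torch.randn(2, 148, 768)
    print(blk(x).shape)  # torch.Size([2, 148, 768])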
class ClassificationHead(nn.Module):
def __init__(self, input_dim: int, target_dim: int):
super().__init__()
self.linear = torch.nn.Linear(input_dim, target_dim)
def forward(self, x):
x = x.view(x.size(0), -1)
y_hat = self.linear(x)
return y_hat
def load_pretrained_weights(model, checkpoint):
import collections
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
model_dict = model.state_dict()
new_state_dict = collections.OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
# If the pretrained state_dict was saved as nn.DataParallel,
# keys would contain "module.", which should be ignored.
if k.startswith("module."):
k = k[7:]
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
# new_state_dict.requires_grad = False
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
print("load_weight", len(matched_layers))
return model
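# Illustrative usage sketch, not part of the original model: only checkpoint entries
# whose key and shape match the target model are copied, so partially compatible
# checkpoints can be loaded. The "checkpoint" here is built in memory for the sketch.
def _demo_load_pretrained_weights():
    src = ClassificationHead(input_dim=768, target_dim=7)
    checkpoint = {"state_dict": src.state_dict()}
    dst = ClassificationHead(input_dim=768, target_dim=7)
    dst = load_pretrained_weights(dst, checkpoint)  # reports 2 matched layers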
class eca_block(nn.Module):
def __init__(self, channel=128, b=1, gamma=2):
super(eca_block, self).__init__()
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(
1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
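# Illustrative usage sketch, not part of the original model: ECA derives its 1-D conv
# kernel size from the channel count (both channel=128 and channel=196 give k=5) and
# re-weights channels without changing the feature-map shape. Shapes are assumptions.
def _demo_eca_block():
    eca = eca_block(channel=128)
    feats = torch.randn(2, 128, 14, 14)
    print(eca(feats).shape)  # torch.Size([2, 128, 14, 14])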
#
#
# class IR20(nn.Module):
# def __init__(self, img_size_=112, num_classes=7, layers=[2, 2, 2, 2]):
# super().__init__()
# norm_layer = nn.BatchNorm2d
# self.img_size = img_size_
# self._norm_layer = norm_layer
# self.num_classes = num_classes
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# self.bn1 = norm_layer(64)
# self.relu = nn.ReLU(inplace=True)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# # self.face_landback = MobileFaceNet([112, 112],136)
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
# self.layer1 = self._make_layer(BasicBlock, 64, 64, layers[0])
# self.layer2 = self._make_layer(BasicBlock, 64, 128, layers[1], stride=2)
# self.layer3 = self._make_layer(AttentionBlock, 128, 256, layers[2], stride=2)
# self.layer4 = self._make_layer(AttentionBlock, 256, 256, layers[3], stride=1)
# self.ir_back = Backbone(50, 51, 52, 0.0, 'ir')
# self.ir_layer = nn.Linear(1024, 512)
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\Pretrained_on_MSCeleb.pth.tar',
# # map_location=lambda storage, loc: storage)
# # ir_checkpoint = ir_checkpoint['state_dict']
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
# # checkpoint = torch.load('./checkpoint/Pretrained_on_MSCeleb.pth.tar')
# # pre_trained_dict = checkpoint['state_dict']
# # IR20.load_state_dict(ir_checkpoint, strict=False)
# # self.IR = load_pretrained_weights(IR, ir_checkpoint)
#
# def _make_layer(self, block, inplanes, planes, blocks, stride=1):
# norm_layer = self._norm_layer
# downsample = None
# if stride != 1 or inplanes != planes:
# downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
# layers = []
# layers.append(block(inplanes, planes, stride, downsample))
# inplanes = planes
# for _ in range(1, blocks):
# layers.append(block(inplanes, planes))
# return nn.Sequential(*layers)
#
# def forward(self, x):
# x_ir = self.ir_back(x)
# # x_ir = self.ir_layer(x_ir)
# # print(x_ir.shape)
# # x = F.interpolate(x, size=112)
# # x = self.conv1(x)
# # x = self.bn1(x)
# # x = self.relu(x)
# # x = self.maxpool(x)
# #
# # x = self.layer1(x)
# # x = self.layer2(x)
# # x = self.layer3(x)
# # x = self.layer4(x)
# # print(x.shape)
# # print(x)
# out = x_ir
#
# return out
#
#
# class IR(nn.Module):
# def __init__(self, img_size_=112, num_classes=7):
# super().__init__()
# depth = 8
# # if type == "small":
# # depth = 4
# # if type == "base":
# # depth = 6
# # if type == "large":
# # depth = 8
#
# self.img_size = img_size_
# self.num_classes = num_classes
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# # self.bn1 = norm_layer(64)
# self.relu = nn.ReLU(inplace=True)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# # self.face_landback = MobileFaceNet([112, 112],136)
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
#
# # for param in self.face_landback.parameters():
# # param.requires_grad = False
#
# ###########################################################################333
#
# self.ir_back = IR20()
#
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\ir50.pth',
# # map_location=lambda storage, loc: storage)
# # # ir_checkpoint = ir_checkpoint["model"]
# # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
# # load_state_dict(checkpoint_model, strict=False)
# # self.ir_layer = nn.Linear(1024,512)
#
# #############################################################3
# #
# # self.pyramid_fuse = HyVisionTransformer(in_chans=49, q_chanel = 49, embed_dim=512,
# # depth=depth, num_heads=8, mlp_ratio=2.,
# # drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
#
# # self.se_block = SE_block(input_dim=512)
# self.head = ClassificationHead(input_dim=768, target_dim=self.num_classes)
#
# def forward(self, x):
# B_ = x.shape[0]
# # x_face = F.interpolate(x, size=112)
# # _, x_face = self.face_landback(x_face)
# # x_face = x_face.view(B_, -1, 49).transpose(1,2)
# ############### landmark x_face ([B, 49, 512])
# x_ir = self.ir_back(x)
# # print(x_ir.shape)
# # x_ir = self.ir_layer(x_ir)
# # print(x_ir.shape)
# ############### image x_ir ([B, 49, 512])
#
# # y_hat = self.pyramid_fuse(x_ir, x_face)
# # y_hat = self.se_block(y_hat)
# # y_feat = y_hat
#
# # out = self.head(x_ir)
#
# out = x_ir
# return out
class eca_block(nn.Module):
    # NOTE: this definition overrides the eca_block class defined earlier in the file;
    # only the default channel count differs (196 vs. 128), and both defaults yield the
    # same 1-D conv kernel size (5).
    def __init__(self, channel=196, b=1, gamma=2):
        super(eca_block, self).__init__()
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(
1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
class SE_block(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, input_dim)
        self.relu = nn.ReLU()
        self.linear2 = torch.nn.Linear(input_dim, input_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # squeeze-and-excitation style gating over the feature vector
        x1 = self.linear1(x)
        x1 = self.relu(x1)
        x1 = self.linear2(x1)
        x1 = self.sigmoid(x1)
        x = x * x1
        return x
class VisionTransformer(nn.Module):
def __init__(
self,
img_size=14,
patch_size=14,
in_c=147,
num_classes=7,
embed_dim=768,
depth=6,
num_heads=8,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
representation_size=None,
distilled=False,
drop_ratio=0.0,
attn_drop_ratio=0.0,
drop_path_ratio=0.0,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_c (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_ratio (float): dropout rate
attn_drop_ratio (float): attention dropout rate
drop_path_ratio (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
"""
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
self.se_block = SE_block(input_dim=embed_dim)
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
)
num_patches = self.patch_embed.num_patches
self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
        self.dist_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        )
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        # self.IR = IR()
        self.eca_block = eca_block()
        # self.ir_back = Backbone(50, 0.0, 'ir')
        # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
        # # ir_checkpoint = ir_checkpoint["model"]
        # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
        self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
        self.IRLinear1 = nn.Linear(1024, 768)
        self.IRLinear2 = nn.Linear(768, 512)
dpr = [
x.item() for x in torch.linspace(0, drop_path_ratio, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
Block(
dim=embed_dim,
in_chans=in_c,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_ratio=drop_ratio,
attn_drop_ratio=attn_drop_ratio,
drop_path_ratio=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(
OrderedDict(
[
("fc", nn.Linear(embed_dim, representation_size)),
("act", nn.Tanh()),
]
)
)
else:
self.has_logits = False
self.pre_logits = nn.Identity()
# Classifier head(s)
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = (
nn.Linear(self.embed_dim, self.num_classes)
if num_classes > 0
else nn.Identity()
)
# Weight init
nn.init.trunc_normal_(self.pos_embed, std=0.02)
if self.dist_token is not None:
nn.init.trunc_normal_(self.dist_token, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(_init_vit_weights)
def forward_features(self, x):
        # NOTE: the patch-embedding step is bypassed in this variant; x is expected
        # to already be a token sequence of shape [B, in_c, embed_dim].
        # x = self.patch_embed(x)  # [B, C, H, W] -> [B, num_patches, embed_dim]
        # expand the class token: [1, 1, embed_dim] -> [B, 1, embed_dim]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat(
(cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1
)
# print(x.shape)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
# B = x.shape[0]
# print(x)
# x = self.eca_block(x)
# x = self.IR(x)
# x = eca_block(x)
# x = self.ir_back(x)
# print(x.shape)
# x = self.CON1(x)
# x = x.view(-1, 196, 768)
#
# # print(x.shape)
# # x = self.IRLinear1(x)
# # print(x)
# x_cls = torch.mean(x, 1).view(B, 1, -1)
# x = torch.cat((x_cls, x), dim=1)
# # print(x.shape)
# x = self.pos_drop(x + self.pos_embed)
# # print(x.shape)
# x = self.blocks(x)
# # print(x)
# x = self.norm(x)
# # print(x)
# # x1 = self.IRLinear2(x)
# x1 = x[:, 0, :]
# print(x1)
# print(x1.shape)
x = self.forward_features(x)
# # print(x.shape)
# if self.head_dist is not None:
# x, x_dist = self.head(x[0]), self.head_dist(x[1])
# if self.training and not torch.jit.is_scripting():
# # during inference, return the average of both classifier predictions
# return x, x_dist
# else:
# return (x + x_dist) / 2
# else:
# print(x.shape)
x = self.se_block(x)
x1 = self.head(x)
return x1
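# Illustrative usage sketch, not part of the original model: in this variant the
# patch-embedding step is bypassed in forward(), so the model consumes a sequence of
# in_c pre-extracted 768-d feature tokens (e.g. from a CNN backbone) and returns
# num_classes logits. The batch size below is an assumption.
def _demo_vision_transformer():
    model = VisionTransformer(in_c=147, num_classes=7, depth=6, num_heads=8)
    tokens = torch.randn(2, 147, 768)  # [B, in_c, embed_dim]
    logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 7])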
def _init_vit_weights(m):
"""
ViT weight initialization
:param m: module
"""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def vit_base_patch16_224(num_classes: int = 7):
"""
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  (extraction code: eu9f)
"""
model = VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=None,
num_classes=num_classes,
)
return model
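# Illustrative usage sketch, not part of the original model: the factory presets only
# change the transformer depth/width; in this customized variant the forward path
# still expects pre-extracted feature tokens (in_c keeps its default of 147) rather
# than raw 224x224 images.
def _demo_vit_base_patch16_224():
    model = vit_base_patch16_224(num_classes=7)
    tokens = torch.randn(2, 147, 768)
    print(model(tokens).shape)  # torch.Size([2, 7])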
def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
"""
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
"""
model = VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=768 if has_logits else None,
num_classes=num_classes,
)
return model
def vit_base_patch32_224(num_classes: int = 1000):
"""
ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  (extraction code: s5hl)
"""
model = VisionTransformer(
img_size=224,
patch_size=32,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=None,
num_classes=num_classes,
)
return model
def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
"""
ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
"""
model = VisionTransformer(
img_size=224,
patch_size=32,
embed_dim=768,
depth=12,
num_heads=12,
representation_size=768 if has_logits else None,
num_classes=num_classes,
)
return model
def vit_large_patch16_224(num_classes: int = 1000):
"""
ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  (extraction code: qqt8)
"""
model = VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=None,
num_classes=num_classes,
)
return model
def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
"""
ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
"""
model = VisionTransformer(
img_size=224,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=1024 if has_logits else None,
num_classes=num_classes,
)
return model
def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
"""
ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
weights ported from official Google JAX impl:
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
"""
model = VisionTransformer(
img_size=224,
patch_size=32,
embed_dim=1024,
depth=24,
num_heads=16,
representation_size=1024 if has_logits else None,
num_classes=num_classes,
)
return model
def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
"""
ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: converted weights not currently available, too large for github release hosting.
"""
model = VisionTransformer(
img_size=224,
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16,
representation_size=1280 if has_logits else None,
num_classes=num_classes,
)
return model