"""Vision-Transformer style classifier that consumes pre-computed token features.

Adapted from rwightman's timm implementation:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

Unlike the stock ViT, ``VisionTransformer.forward`` does not run a patch
embedding: it expects a tensor of tokens shaped ``[B, in_c, embed_dim]``.
"""
import math
from collections import OrderedDict
from functools import partial

import torch
import torch.hub
import torch.nn as nn
import torch.nn.functional as F

# The timm names below are either unused in this module or shadowed by the
# local DropPath / Mlp / Block classes, so timm is treated as optional.
try:
    from timm.layers import DropPath, to_2tuple, trunc_normal_
    from timm.models import register_model
    from timm.models.vision_transformer import _cfg, Mlp, Block
except ImportError:
    pass


def _to_pair(value):
    """Return ``value`` as a 2-tuple; ints are duplicated, pairs are kept."""
    if isinstance(value, (tuple, list)):
        if len(value) != 2:
            raise ValueError(f"expected an int or a pair, got {value!r}")
        return tuple(value)
    return (value, value)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding equal to the dilation and no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without bias."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (stochastic depth) per sample.

    Each sample in the batch is zeroed with probability ``drop_prob``; the
    survivors are divided by ``1 - drop_prob``. The input is returned
    untouched when not training or when ``drop_prob`` is ``0``/``None``.

    See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    # `not drop_prob` also covers None, which DropPath uses as its default.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor


class BasicBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a residual connection."""

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the caller supplied a downsample module.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)


class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """Project a 2D feature map to a token sequence with a 1x1 convolution.

    Args:
        img_size: int or pair, only used to derive ``grid_size``/``num_patches``.
        patch_size: int or pair, only used to derive ``grid_size``/``num_patches``
            (the projection itself uses ``kernel_size=1``).
        in_c: number of input channels of the projection.
        embed_dim: number of output channels of the projection.
        norm_layer: optional callable ``norm_layer(embed_dim)``; identity if unset.
    """

    def __init__(
        self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
    ):
        super().__init__()
        img_size = _to_pair(img_size)
        patch_size = _to_pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Previously hard-coded to Conv2d(256, 768); identical for the defaults.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=1)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        # flatten: [B, C, H, W] -> [B, C, HW]; transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class Attention(nn.Module):
    """Multi-head self-attention over the first ``in_chans + 1`` tokens.

    Args:
        dim: embedding dimension of each token.
        in_chans: number of non-class tokens; ``forward`` keeps the first
            ``in_chans + 1`` tokens of its input.
        num_heads: number of attention heads; must divide ``dim``.
        qkv_bias: add a bias to the fused qkv projection.
        qk_scale: overrides the default ``head_dim ** -0.5`` scale if set.
        attn_drop_ratio: dropout applied to the attention weights.
        proj_drop_ratio: dropout applied after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim,
        in_chans,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_ratio=0.0,
        proj_drop_ratio=0.0,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        # The head count used for the reshape must match the one used for the
        # scale; it used to be pinned to 8 regardless of the argument.
        self.num_heads = num_heads
        self.img_chanel = in_chans + 1
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # Keep the leading tokens: [B, in_chans + 1, dim]
        x_img = x[:, : self.img_chanel, :]
        B, N, C = x_img.shape

        # qkv():   -> [B, N, 3 * C]
        # reshape: -> [B, N, 3, num_heads, C // num_heads]
        # permute: -> [3, B, num_heads, N, C // num_heads]
        qkv = (
            self.qkv(x_img)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        # [B, num_heads, N, N] attention weights
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back: -> [B, N, C]
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)


class AttentionBlock(nn.Module):
    """``BasicBlock`` variant that gates the residual branch with ``eca_block``."""

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.inplanes = inplanes
        self.eca_block = eca_block()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Channel gating is applied before the shortcut is added.
        out = self.eca_block(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    The same dropout module is applied after the activation and after ``fc2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return self.drop(x)


class Block(nn.Module):
    """Pre-norm transformer encoder block: attention + MLP, both residual.

    Args:
        dim: token embedding dimension.
        in_chans: number of non-class tokens (forwarded to ``Attention``).
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale: forwarded to ``Attention``.
        drop_ratio: dropout for the attention projection and the MLP.
        attn_drop_ratio: dropout on the attention weights.
        drop_path_ratio: stochastic-depth rate for both residual branches.
        act_layer: activation class used by the MLP.
        norm_layer: normalization class, called as ``norm_layer(dim)``.
    """

    def __init__(
        self,
        dim,
        in_chans,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.img_chanel = in_chans + 1
        # NOTE: `conv` is not used by forward(); it is kept so that existing
        # checkpoints containing its weights still load.
        self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
        self.attn = Attention(
            dim,
            in_chans=in_chans,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better
        # than dropout here
        self.drop_path = (
            DropPath(drop_path_ratio) if drop_path_ratio > 0.0 else nn.Identity()
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop_ratio,
        )

    def forward(self, x):
        # The residual add requires x to have exactly in_chans + 1 tokens,
        # because Attention only returns that many.
        x_img = x + self.drop_path(self.attn(self.norm1(x)))
        return x_img + self.drop_path(self.mlp(self.norm2(x_img)))


class ClassificationHead(nn.Module):
    """Flatten everything but the batch dimension and apply one linear layer."""

    def __init__(self, input_dim: int, target_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, target_dim)

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        return self.linear(x)


def load_pretrained_weights(model, checkpoint):
    """Copy every checkpoint tensor whose name and shape match ``model``.

    Args:
        model: module to load into (modified in place).
        checkpoint: a state dict, or a dict holding one under ``"state_dict"``.

    Returns:
        The same ``model`` instance. Non-matching entries are skipped, and the
        number of matched entries is printed.
    """
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers = []

    for k, v in state_dict.items():
        # If the pretrained state_dict was saved as nn.DataParallel,
        # keys would contain "module.", which should be ignored.
        if k.startswith("module."):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    print("load_weight", len(matched_layers))
    return model


# NOTE: commented-out IR20 / IR backbone experiments that used to live here
# were removed; this module never instantiated them.


class eca_block(nn.Module):
    """Channel gate: global average pool -> 1D conv over channels -> sigmoid.

    The conv kernel size is derived from ``channel``, ``b`` and ``gamma`` as
    ``int(abs((log2(channel) + b) / gamma))``, bumped to the next odd number.
    ``channel`` is only used for that computation; the layer itself accepts
    any number of channels.
    """

    def __init__(self, channel=196, b=1, gamma=2):
        super().__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # [B, C, H, W] -> [B, C, 1, 1]
        y = self.avg_pool(x)
        # Treat the channel axis as the 1D sequence the conv slides over.
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class SE_block(nn.Module):
    """Feature gate: ``x * sigmoid(linear2(relu(linear1(x))))``."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, input_dim)
        self.relu = nn.ReLU()
        self.linear2 = torch.nn.Linear(input_dim, input_dim)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        gate = self.linear1(x)
        gate = self.relu(gate)
        gate = self.linear2(gate)
        gate = self.sigmod(gate)
        return x * gate


class VisionTransformer(nn.Module):
    """Transformer encoder + gated classification head over token features."""

    def __init__(
        self,
        img_size=14,
        patch_size=14,
        in_c=147,
        num_classes=7,
        embed_dim=768,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        representation_size=None,
        distilled=False,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
    ):
        """
        Args:
            img_size (int, tuple): only forwarded to ``embed_layer``
            patch_size (int, tuple): only forwarded to ``embed_layer``
            in_c (int): number of input tokens expected by ``forward``
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation
                layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as
                in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): activation layer, defaults to ``nn.GELU``
        """
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        self.has_logits = bool(representation_size) and not distilled
        # num_features is the width seen by the gate and the classifier.
        self.num_features = representation_size if self.has_logits else embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per input token plus the class (and distillation) token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, in_c + self.num_tokens, embed_dim)
        )
        self.dist_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        )
        self.pos_drop = nn.Dropout(p=drop_ratio)
        self.se_block = SE_block(input_dim=self.num_features)

        # NOTE: patch_embed, eca_block, CON1, IRLinear1 and IRLinear2 are not
        # used by forward(); their names and hard-coded sizes are kept so that
        # existing checkpoints still load.
        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
        )
        self.head = ClassificationHead(
            input_dim=self.num_features, target_dim=self.num_classes
        )
        self.eca_block = eca_block()
        self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
        self.IRLinear1 = nn.Linear(1024, 768)
        self.IRLinear2 = nn.Linear(768, 512)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    # Block/Attention keep in_chans + 1 tokens, so account for
                    # the extra distillation token when present.
                    in_chans=in_c + self.num_tokens - 1,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_ratio=drop_ratio,
                    attn_drop_ratio=attn_drop_ratio,
                    drop_path_ratio=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if self.has_logits:
            self.pre_logits = nn.Sequential(
                OrderedDict(
                    [
                        ("fc", nn.Linear(embed_dim, representation_size)),
                        ("act", nn.Tanh()),
                    ]
                )
            )
        else:
            self.pre_logits = nn.Identity()

        # Distillation head
        self.head_dist = None
        if distilled:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        """Encode tokens ``[B, in_c, embed_dim]`` and return the class feature.

        Returns the (pre-logits) class-token feature, or the pair
        ``(class_feature, distillation_feature)`` for distilled models.
        """
        batch = x.shape[0]
        # [1, 1, embed_dim] -> [B, 1, embed_dim]
        prefix = [self.cls_token.expand(batch, -1, -1)]
        if self.dist_token is not None:
            prefix.append(self.dist_token.expand(batch, -1, -1))
        x = torch.cat((*prefix, x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Return class logits for tokens ``[B, in_c, embed_dim]``.

        Distilled models return ``(logits, logits_dist)`` while training and
        the average of both at inference.
        """
        features = self.forward_features(x)
        if self.head_dist is None:
            return self.head(self.se_block(features))

        cls_feat, dist_feat = features
        logits = self.head(self.se_block(cls_feat))
        logits_dist = self.head_dist(dist_feat)
        if self.training and not torch.jit.is_scripting():
            return logits, logits_dist
        # during inference, return the average of both classifier predictions
        return (logits + logits_dist) / 2


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


# NOTE(review): the builders below reuse the ViT size presets from
# https://arxiv.org/abs/2010.11929. img_size/patch_size only configure the
# unused patch_embed layer, and this model differs from the stock ViT, so the
# upstream weight files mentioned in the docstrings may only partially match -
# confirm before relying on them.


def vit_base_patch16_224(num_classes: int = 7):
    """Build the ViT-Base/16 preset (embed 768, depth 12, 12 heads).

    Upstream ImageNet-1k weights (ported from the official Google JAX impl):
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """Build the ViT-Base/16 preset, optionally with a 768-wide pre-logits layer.

    Upstream ImageNet-21k weights (ported from the official Google JAX impl):
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=768 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """Build the ViT-Base/32 preset (embed 768, depth 12, 12 heads).

    Upstream ImageNet-1k weights (ported from the official Google JAX impl):
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """Build the ViT-Base/32 preset, optionally with a 768-wide pre-logits layer.

    Upstream ImageNet-21k weights (ported from the official Google JAX impl):
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=768 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """Build the ViT-Large/16 preset (embed 1024, depth 24, 16 heads).

    Upstream ImageNet-1k weights (ported from the official Google JAX impl):
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """Build the ViT-Large/16 preset, optionally with a 1024-wide pre-logits layer.

    Upstream ImageNet-21k weights (ported from the official Google JAX impl):
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=1024 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """Build the ViT-Large/32 preset, optionally with a 1024-wide pre-logits layer.

    Upstream ImageNet-21k weights (ported from the official Google JAX impl):
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=1024 if has_logits else None,
        num_classes=num_classes,
    )
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """Build the ViT-Huge/14 preset (embed 1280, depth 32, 16 heads).

    NOTE: converted upstream weights are not currently available (too large
    for github release hosting).
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        representation_size=1280 if has_logits else None,
        num_classes=num_classes,
    )
    return model