|
""" |
|
original code from rwightman: |
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
|
""" |
|
|
|
from functools import partial |
|
from collections import OrderedDict |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.hub |
|
from functools import partial |
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.hub |
|
from functools import partial |
|
import math |
|
|
|
from timm.layers import DropPath, to_2tuple, trunc_normal_ |
|
from timm.models import register_model |
|
from timm.models.vision_transformer import _cfg, Mlp, Block |
|
|
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups for a grouped convolution.
        dilation: spacing between kernel elements; also used as the padding.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)
    return pointwise
|
|
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop whole samples at random (Stochastic Depth) and rescale the rest.

    Intended for the main path of a residual block: each sample in the batch
    is either zeroed entirely or kept and divided by the keep probability.

    Args:
        x: input tensor; dimension 0 is treated as the sample dimension.
        drop_prob: probability of dropping a sample. ``None`` and ``0.0``
            both mean "never drop".
        training: when False the input is returned untouched.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor of
        the same shape.
    """
    # ``not drop_prob`` covers 0.0 as well as None, which is the default
    # ``drop_prob`` of the DropPath module in this file.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Every sample is dropped. Dividing by keep_prob would produce
        # inf/NaN instead of zeros, so return zeros directly.
        return torch.zeros_like(x)
    # One random draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor
|
|
|
|
|
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with batch normalization.

    The input (optionally passed through ``downsample``) is added to the
    output of the second conv/BN pair before the final ReLU.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the
            shortcut branch; the raw input is used when it is None.
    """

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the residual block on ``x`` and return the activated sum."""
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Shortcut branch, projected only when a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)
|
|
|
|
|
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    Args:
        drop_prob: probability of dropping a sample's path. ``None`` (the
            default) and ``0.0`` make the module an identity.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` unchanged in eval mode or when nothing is dropped."""
        # Short-circuit here so that a default-constructed module
        # (drop_prob=None) never reaches the ``1 - drop_prob`` arithmetic
        # inside drop_path.
        if not self.training or not self.drop_prob:
            return x
        return drop_path(x, self.drop_prob, self.training)
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Project a 2D feature map to a sequence of embedding tokens.

    The projection is a 1x1 convolution, so every spatial location of the
    input becomes one token. ``img_size`` and ``patch_size`` are only used to
    derive the ``grid_size`` / ``num_patches`` bookkeeping attributes.

    Args:
        img_size: input size, either an int (square) or a (height, width) pair.
        patch_size: patch size, either an int (square) or a pair.
        in_c: number of input channels of the projection.
        embed_dim: number of output channels, i.e. the token width.
        norm_layer: optional callable ``norm_layer(embed_dim)`` building the
            normalization applied to the tokens; identity when falsy.
    """

    def __init__(
        self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
    ):
        super().__init__()
        img_size = self._as_pair(img_size)
        patch_size = self._as_pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Use the constructor arguments rather than fixed 256 -> 768 so the
        # projection width always agrees with norm_layer(embed_dim). The
        # defaults reproduce the previous hard-coded sizes.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=1)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a 2-tuple, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to tokens of shape (B, H*W, embed_dim).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                "PatchEmbed expects a 4D (B, C, H, W) input, got %d dims" % x.dim()
            )
        # (B, C, H, W) -> (B, E, H, W) -> (B, E, H*W) -> (B, H*W, E)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over the leading ``in_chans + 1`` tokens.

    Only the first ``in_chans + 1`` tokens of the input take part in the
    attention; any further tokens are dropped from the output.

    Args:
        dim: token embedding width; must be divisible by ``num_heads``.
        in_chans: number of non-class tokens; ``in_chans + 1`` tokens are kept.
        num_heads: number of attention heads. This used to be ignored in
            favor of a fixed 8 heads while the scale still used this
            argument, so heads and scale disagreed for any other value.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: overrides the default ``head_dim ** -0.5`` scale when given.
        attn_drop_ratio: dropout probability on the attention weights.
        proj_drop_ratio: dropout probability after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim,
        in_chans,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_ratio=0.0,
        proj_drop_ratio=0.0,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                "dim (%d) must be divisible by num_heads (%d)" % (dim, num_heads)
            )
        self.num_heads = num_heads
        self.img_chanel = in_chans + 1
        head_dim = dim // num_heads
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        """Attend over ``x[:, :in_chans + 1]``.

        Args:
            x: tensor of shape (B, tokens, dim).

        Returns:
            Tensor of shape (B, min(tokens, in_chans + 1), dim).
        """
        x_img = x[:, : self.img_chanel, :]
        B, N, C = x_img.shape

        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x_img)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_img = self.proj(x_img)
        x_img = self.proj_drop(x_img)
        return x_img
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Residual 3x3 conv block whose main branch is re-weighted by ``eca_block``.

    Same layout as a two-convolution residual block, with the output of the
    second conv/BN pair passed through an ``eca_block`` (built with its
    default arguments, regardless of ``planes``) before the shortcut is added.

    Args:
        inplanes: number of input channels (also stored as ``self.inplanes``).
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module producing the shortcut from the input.
    """

    __constants__ = ["downsample"]

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

        self.inplanes = inplanes
        self.eca_block = eca_block()

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        # Channel re-weighting is applied before the shortcut is added.
        features = self.eca_block(features)

        shortcut = x if self.downsample is None else self.downsample(x)

        features += shortcut
        return self.relu(features)
|
|
|
|
|
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same dropout module is applied after the activation and after the
    second linear layer.

    Args:
        in_features: width of the input.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: width of the output; falls back to ``in_features``
            when falsy.
        act_layer: zero-argument callable building the activation module.
        drop: dropout probability.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2 and dropout in that order."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention then MLP, both residual.

    Args:
        dim: token embedding width.
        in_chans: number of non-class tokens, forwarded to ``Attention``.
        num_heads: number of attention heads, forwarded to ``Attention``.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        qk_scale: forwarded to ``Attention``.
        drop_ratio: dropout used by the attention projection and the MLP.
        attn_drop_ratio: dropout on the attention weights.
        drop_path_ratio: stochastic-depth rate; ``nn.Identity`` is used
            instead of ``DropPath`` when it is not greater than 0.
        act_layer: activation factory for the MLP.
        norm_layer: factory called as ``norm_layer(dim)`` for both norms.
    """

    def __init__(
        self,
        dim,
        in_chans,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        token_count = in_chans + 1

        self.norm1 = norm_layer(dim)
        self.img_chanel = token_count

        # NOTE: this Conv1d is registered (so it appears in the state dict)
        # but forward() below never calls it.
        self.conv = nn.Conv1d(token_count, token_count, 1)
        self.attn = Attention(
            dim,
            in_chans=in_chans,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_ratio=attn_drop_ratio,
            proj_drop_ratio=drop_ratio,
        )

        if drop_path_ratio > 0.0:
            self.drop_path = DropPath(drop_path_ratio)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop_ratio,
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections.

        NOTE(review): ``Attention`` in this file keeps at most
        ``in_chans + 1`` tokens, so the residual add presumably expects ``x``
        to carry exactly that many tokens -- confirm against callers.
        """
        attended = x + self.drop_path(self.attn(self.norm1(x)))
        return attended + self.drop_path(self.mlp(self.norm2(attended)))
|
|
|
|
|
class ClassificationHead(nn.Module):
    """Single linear layer applied to the flattened per-sample features.

    Args:
        input_dim: number of features per sample after flattening.
        target_dim: number of outputs.
    """

    def __init__(self, input_dim: int, target_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, target_dim)

    def forward(self, x):
        """Flatten everything but dimension 0 and apply the linear layer."""
        flat = x.view(x.shape[0], -1)
        return self.linear(flat)
|
|
|
|
|
def load_pretrained_weights(model, checkpoint):
    """Copy every shape-compatible tensor from ``checkpoint`` into ``model``.

    Args:
        model: module whose ``state_dict`` / ``load_state_dict`` are used.
        checkpoint: mapping of parameter names to tensors, or a mapping that
            holds such a mapping under the key ``"state_dict"``.

    Returns:
        The same ``model`` object, after loading.

    A leading ``"module."`` on a checkpoint key is stripped before matching.
    Entries whose name is unknown to the model or whose size differs are
    skipped silently. The number of matched entries is printed.
    """
    from collections import OrderedDict

    source = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    target = model.state_dict()

    prefix = "module."
    accepted = OrderedDict()
    n_matched = 0
    for name, tensor in source.items():
        key = name[len(prefix):] if name.startswith(prefix) else name
        if key not in target or target[key].size() != tensor.size():
            continue
        accepted[key] = tensor
        n_matched += 1

    # Start from the model's own weights so unmatched entries keep their
    # current values and the load below sees a complete state dict.
    target.update(accepted)
    model.load_state_dict(target)
    print("load_weight", n_matched)
    return model
|
|
|
|
|
class eca_block(nn.Module):
    """Channel gate: global average pool -> 1D conv across channels -> sigmoid.

    NOTE: a second class with this same name is defined later in this file;
    once the module is fully imported, the name refers to that later one.

    Args:
        channel: channel count used only to derive the 1D kernel size.
        b: offset added to ``log2(channel)`` in the kernel-size formula.
        gamma: divisor in the kernel-size formula.
    """

    def __init__(self, channel=128, b=1, gamma=2):
        super().__init__()
        # Kernel size is int(|(log2(channel) + b) / gamma|), bumped to the
        # next odd number when even.
        k = int(abs((math.log(channel, 2) + b) / gamma))
        if k % 2 == 0:
            k += 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Scale ``x`` by a per-channel gate computed from its spatial mean."""
        pooled = self.avg_pool(x)
        # Lay the channels out as a length-C sequence with one feature
        # channel for the 1D convolution, then restore the pooled layout.
        channel_seq = pooled.squeeze(-1).transpose(-1, -2)
        gate = self.conv(channel_seq).transpose(-1, -2).unsqueeze(-1)
        gate = self.sigmoid(gate)
        return x * gate.expand_as(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class eca_block(nn.Module):
    """Channel gate built from pooled statistics and a 1D convolution.

    The input is average-pooled to one value per channel, a bias-free 1D
    convolution mixes neighbouring channels, and the sigmoid of the result
    multiplies the input channel-wise.

    Args:
        channel: channel count used only to derive the 1D kernel size.
        b: offset added to ``log2(channel)`` in the kernel-size formula.
        gamma: divisor in the kernel-size formula.
    """

    def __init__(self, channel=196, b=1, gamma=2):
        super().__init__()
        raw_size = int(abs((math.log(channel, 2) + b) / gamma))
        # Force an odd kernel so that the padding below is symmetric.
        odd_size = raw_size + 1 if raw_size % 2 == 0 else raw_size
        half_width = (odd_size - 1) // 2

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=odd_size, padding=half_width, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel sigmoid gate."""
        stats = self.avg_pool(x).squeeze(-1).transpose(-1, -2)
        mixed = self.conv(stats)
        weights = self.sigmoid(mixed.transpose(-1, -2).unsqueeze(-1))
        return x * weights.expand_as(x)
|
|
|
|
|
class SE_block(nn.Module):
    """Feature gate: Linear -> ReLU -> Linear -> Sigmoid, multiplied onto the input.

    Both linear layers map ``input_dim`` to ``input_dim`` (no bottleneck).

    Args:
        input_dim: size of the last dimension of the input.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(input_dim, input_dim)
        # The attribute name "sigmod" is kept as-is; it is part of the
        # module's attribute interface.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled element-wise by the learned gate."""
        gate = self.sigmod(self.linear2(self.relu(self.linear1(x))))
        return x * gate
|
|
|
|
|
class VisionTransformer(nn.Module):
    """Transformer encoder over pre-computed tokens with an SE-gated linear head.

    ``forward`` expects token features rather than images: a class token is
    prepended and a learned position embedding of ``in_c + 1`` positions is
    added, so the input is expected to have shape (B, in_c, embed_dim).

    NOTE(review): the ``distilled=True`` path looks incomplete. Two tokens
    are prepended while ``pos_embed`` has only ``in_c + 1`` positions, and
    ``forward`` hands the tuple returned by ``forward_features`` straight to
    ``se_block``. Verify before relying on it.
    """

    def __init__(
        self,
        img_size=14,
        patch_size=14,
        in_c=147,
        num_classes=7,
        embed_dim=768,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        representation_size=None,
        distilled=False,
        drop_ratio=0.0,
        attn_drop_ratio=0.0,
        drop_path_ratio=0.0,
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
    ):
        """
        Args:
            img_size (int): forwarded to ``embed_layer``
            patch_size (int): forwarded to ``embed_layer``
            in_c (int): number of input tokens (the class token is extra)
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): forwarded to every ``Block``
            representation_size (Optional[int]): enable and set representation
                layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as
                in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer; it is constructed
                but not used by ``forward``
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): activation layer for the MLPs
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # Width of what forward_features() hands to se_block/head. With a
        # pre-logits layer this is representation_size, not embed_dim, so
        # both modules must be sized with it or forward() fails whenever the
        # two differ.
        use_pre_logits = bool(representation_size) and not distilled
        feature_dim = representation_size if use_pre_logits else embed_dim

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        self.se_block = SE_block(input_dim=feature_dim)

        # Built with fixed channel sizes and never called in forward(); kept
        # so the module's state dict keeps the same entries.
        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
        )
        self.head = ClassificationHead(
            input_dim=feature_dim, target_dim=self.num_classes
        )
        self.dist_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        )

        # eca_block is assigned twice, as in the original code. The first
        # construction is kept because it fixes the module's registration
        # position and draws random initial weights; dropping it could shift
        # the seeded initialization of the modules created afterwards.
        self.eca_block = eca_block()

        # The three layers below are registered but unused in forward().
        self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
        self.IRLinear1 = nn.Linear(1024, 768)
        self.IRLinear2 = nn.Linear(768, 512)
        self.eca_block = eca_block()

        # Stochastic depth rate increases linearly from 0 to drop_path_ratio.
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    in_chans=in_c,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_ratio=drop_ratio,
                    attn_drop_ratio=attn_drop_ratio,
                    drop_path_ratio=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Optional representation (pre-logits) layer.
        if use_pre_logits:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(
                OrderedDict(
                    [
                        ("fc", nn.Linear(embed_dim, representation_size)),
                        ("act", nn.Tanh()),
                    ]
                )
            )
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        self.head_dist = None
        if distilled:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        """Prepend the special token(s), run the encoder, return token 0.

        Args:
            x: token tensor, expected as (B, in_c, embed_dim).

        Returns:
            ``pre_logits(x[:, 0])`` when there is no distillation token,
            otherwise the tuple ``(x[:, 0], x[:, 1])``.
        """
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat(
                (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1
            )

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        """Return class logits: features -> SE gate -> linear head."""
        x = self.forward_features(x)
        x = self.se_block(x)
        x1 = self.head(x)
        return x1
|
|
|
|
|
def _init_vit_weights(m):
    """Initialize one module in place; meant to be used with ``Module.apply``.

    - ``nn.Linear``: truncated-normal weights (std 0.01), zero bias.
    - ``nn.Conv2d``: Kaiming-normal weights (fan-out mode), zero bias.
    - ``nn.LayerNorm``: zero bias, unit weight.

    Any other module type is left untouched.

    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.01)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
        return
    else:
        return

    # Shared tail for Linear and Conv2d, whose bias is optional.
    if m.bias is not None:
        nn.init.zeros_(m.bias)
|
|
|
|
|
def vit_base_patch16_224(num_classes: int = 7):
    """
    Build a VisionTransformer with the ViT-Base (ViT-B/16) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 768, depth 12, 12 heads, no representation layer.

    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f

    No weights are loaded here; the model is returned freshly initialized.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )
    return VisionTransformer(**config)
|
|
|
|
|
def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    Build a VisionTransformer with the ViT-Base (ViT-B/16) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 768, depth 12, 12 heads. A 768-wide representation layer is
    enabled when ``has_logits`` is True.

    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth

    No weights are loaded here; the model is returned freshly initialized.
    """
    representation = 768 if has_logits else None
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=representation,
        num_classes=num_classes,
    )
|
|
|
|
|
def vit_base_patch32_224(num_classes: int = 1000):
    """
    Build a VisionTransformer with the ViT-Base (ViT-B/32) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 768, depth 12, 12 heads, no representation layer.

    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl

    No weights are loaded here; the model is returned freshly initialized.
    """
    config = dict(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=None,
        num_classes=num_classes,
    )
    return VisionTransformer(**config)
|
|
|
|
|
def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    Build a VisionTransformer with the ViT-Base (ViT-B/32) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 768, depth 12, 12 heads. A 768-wide representation layer is
    enabled when ``has_logits`` is True.

    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth

    No weights are loaded here; the model is returned freshly initialized.
    """
    representation = 768 if has_logits else None
    return VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        representation_size=representation,
        num_classes=num_classes,
    )
|
|
|
|
|
def vit_large_patch16_224(num_classes: int = 1000):
    """
    Build a VisionTransformer with the ViT-Large (ViT-L/16) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 1024, depth 24, 16 heads, no representation layer.

    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8

    No weights are loaded here; the model is returned freshly initialized.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=None,
        num_classes=num_classes,
    )
    return VisionTransformer(**config)
|
|
|
|
|
def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    Build a VisionTransformer with the ViT-Large (ViT-L/16) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 1024, depth 24, 16 heads. A 1024-wide representation layer is
    enabled when ``has_logits`` is True.

    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth

    No weights are loaded here; the model is returned freshly initialized.
    """
    representation = 1024 if has_logits else None
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=representation,
        num_classes=num_classes,
    )
|
|
|
|
|
def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    Build a VisionTransformer with the ViT-Large (ViT-L/32) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 1024, depth 24, 16 heads. A 1024-wide representation layer is
    enabled when ``has_logits`` is True.

    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth

    No weights are loaded here; the model is returned freshly initialized.
    """
    representation = 1024 if has_logits else None
    return VisionTransformer(
        img_size=224,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        representation_size=representation,
        num_classes=num_classes,
    )
|
|
|
|
|
def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    Build a VisionTransformer with the ViT-Huge (ViT-H/14) hyper-parameters
    from the original paper (https://arxiv.org/abs/2010.11929):
    embed_dim 1280, depth 32, 16 heads. A 1280-wide representation layer is
    enabled when ``has_logits`` is True.

    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    representation = 1280 if has_logits else None
    return VisionTransformer(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        representation_size=representation,
        num_classes=num_classes,
    )
|
|