import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.layers import trunc_normal_, DropPath
from thop import profile

from .mobilefacenet import MobileFaceNet
from .ir50 import Backbone
from .vit_model import VisionTransformer, PatchEmbed


def load_pretrained_weights(model, checkpoint):
    """Copy every shape-compatible tensor from ``checkpoint`` into ``model``.

    A checkpoint entry is loaded only when its key exists in
    ``model.state_dict()`` and the tensor sizes match; all other entries are
    skipped. A leading ``"module."`` prefix (present when the weights were
    saved from an ``nn.DataParallel`` wrapper) is stripped before matching.

    Args:
        model: ``nn.Module`` to load into (modified in place).
        checkpoint: either a raw state dict, or a dict holding one under the
            ``"state_dict"`` key.

    Returns:
        The same ``model`` instance.
    """
    import collections

    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        # If the pretrained state_dict was saved as nn.DataParallel,
        # keys would contain "module.", which should be ignored.
        if k.startswith("module."):
            k = k[7:]

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Merge into the model's own state dict so entries the checkpoint did not
    # provide keep their current values and the strict load still sees every key.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    print("load_weight", len(matched_layers))
    return model


def window_partition(x, window_size, h_w, w_w):
    """Cut a channel-last feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C) tensor; expected to satisfy H == h_w * window_size
            and W == w_w * window_size.
        window_size: side length of each window.
        h_w: number of windows along the height axis.
        w_w: number of windows along the width axis.

    Returns:
        local window features (num_windows*B, window_size, window_size, C),
        ordered by batch index first, then row-major over the window grid.
    """
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


class window(nn.Module):
    """LayerNorm a channel-first feature map and split it into token windows."""

    def __init__(self, window_size, dim):
        super(window, self).__init__()
        self.window_size = window_size
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) feature map; H and W should be multiples of
                ``window_size``.

        Returns:
            x_windows: (num_windows*B, window_size*window_size, C) tokens.
            shortcut: the normalized (B, H, W, C) map, later used as the
                residual branch by ``feedforward``.
        """
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        x = self.norm(x)
        shortcut = x
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows, shortcut


class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers".

    Keys and values come from the local window tokens; queries are supplied
    externally (``q_global``) and shared by every window of an image.
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention heads.
            window_size: window size.
            qkv_bias: bool argument for key/value learnable bias.
            qk_scale: optional override for the query scaling factor.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """
        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        # Integer per-head width, consistent with the truncated
        # C / num_heads that forward() uses to reshape q, k and v.
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        # Precompute, for every (query, key) pair inside a window, the row of
        # relative_position_bias_table holding their relative-offset bias.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # "ij" is meshgrid's historical default; passing it explicitly keeps
        # the same grid and silences the deprecation warning.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to be non-negative, then flatten the 2-D offset.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        # Projects to keys and values only (queries are external); the
        # attribute keeps its "qkv" name so checkpoint keys still match.
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        """
        Args:
            x: (num_windows*B, N, C) window tokens providing keys and values;
                N must equal window_size * window_size for the bias to add.
            q_global: (B, 1, num_heads, N, head_dim) queries, as produced by
                ``_to_query``.

        Returns:
            (num_windows*B, N, C) attended tokens.
        """
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = int(torch.div(C, self.num_heads).item())
        B_dim = int(torch.div(B_, B).item())  # windows per image
        kv = (
            self.qkv(x)
            .reshape(B_, N, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        # Tile each image's queries across all of that image's windows; the
        # batch-major ordering matches window_partition's output.
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def _to_channel_last(x):
    """
    Args:
        x: (B, C, H, W)
    Returns:
        x: (B, H, W, C)
    """
    return x.permute(0, 2, 3, 1)


def _to_channel_first(x):
    """Permute (B, H, W, C) to (B, C, H, W)."""
    return x.permute(0, 3, 1, 2)


def _to_query(x, N, num_heads, dim_head):
    """Reshape a channel-last map into per-head queries.

    ``x`` must hold B * N * num_heads * dim_head elements (e.g. a
    (B, H, W, C) map with H * W == N and C == num_heads * dim_head).

    Returns:
        (B, 1, num_heads, N, dim_head) tensor.
    """
    B = x.shape[0]
    x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
    return x


class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension (defaults to in_features).
            out_features: output features dimension (defaults to in_features).
            act_layer: activation function.
            drop: dropout rate, applied after the activation and after fc2.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_reverse(windows, window_size, H, W, h_w, w_w):
    """Inverse of ``window_partition``: stitch windows back into a map.

    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
            (any layout with the same element order, e.g. (num_windows*B, N, C)).
        window_size: Window size
        H: Height of image
        W: Width of image
        h_w: number of windows along the height axis.
        w_w: number of windows along the width axis.
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class feedforward(nn.Module):
    """Un-window attention output, add the residual, then apply an MLP.

    Both residual branches are multiplied by ``gamma1``/``gamma2``: learnable
    per-channel vectors when ``layer_scale`` is an int/float, otherwise 1.0.
    """

    def __init__(
        self,
        dim,
        window_size,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        drop=0.0,
        drop_path=0.0,
        layer_scale=None,
    ):
        super(feedforward, self).__init__()
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
        else:
            # Define the flag on both paths so the attribute always exists.
            self.layer_scale = False
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.window_size = window_size
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, attn_windows, shortcut):
        """
        Args:
            attn_windows: attention output per window, e.g.
                (num_windows*B, window_size*window_size, C).
            shortcut: (B, H, W, C) residual map.

        Returns:
            (B, H, W, C) tensor.
        """
        B, H, W, C = shortcut.shape
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
        return x


class pyramid_trans_expr2(nn.Module):
    """POSTER V2-style expression model.

    Landmark features from a frozen MobileFaceNet act as queries over IR-50
    image features at three scales (window attention); the three scale
    outputs are embedded to ``embed_dim`` tokens, concatenated and passed to
    a VisionTransformer.
    """

    def __init__(
        self,
        img_size=224,
        num_classes=7,
        window_size=(28, 14, 7),
        num_heads=(2, 4, 8),
        dims=(64, 128, 256),
        embed_dim=768,
    ):
        super().__init__()
        self.img_size = img_size
        # Copy into fresh lists so instances never share (or alias a
        # caller's) mutable sequence; tuples as defaults avoid the shared
        # mutable-default pitfall.
        self.num_heads = list(num_heads)
        self.dim_head = [dim // num_head for num_head, dim in zip(num_heads, dims)]
        self.num_classes = num_classes
        self.window_size = list(window_size)
        self.N = [win * win for win in self.window_size]
        self.face_landback = MobileFaceNet([112, 112], 136)

        # Get the directory of the current file
        script_dir = os.path.dirname(os.path.abspath(__file__))
        # Join path components separately (no hard-coded separator) so the
        # paths resolve on POSIX as well as on Windows.
        mobilefacenet_path = os.path.join(
            script_dir, "pretrain", "mobilefacenet_model_best.pth.tar"
        )
        ir50_path = os.path.join(script_dir, "pretrain", "ir50.pth")

        print(mobilefacenet_path)
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load these checkpoints from a trusted source.
        face_landback_checkpoint = torch.load(
            mobilefacenet_path,
            map_location=lambda storage, loc: storage,
            weights_only=False,
        )
        self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])

        # The landmark branch is frozen (no gradient updates).
        for param in self.face_landback.parameters():
            param.requires_grad = False

        self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim)

        self.ir_back = Backbone(50, 0.0, "ir")
        ir_checkpoint = torch.load(
            ir50_path, map_location=lambda storage, loc: storage, weights_only=False
        )
        self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        self.attn1 = WindowAttentionGlobal(
            dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
        )
        self.attn2 = WindowAttentionGlobal(
            dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
        )
        self.attn3 = WindowAttentionGlobal(
            dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
        )
        self.window1 = window(window_size=window_size[0], dim=dims[0])
        self.window2 = window(window_size=window_size[1], dim=dims[1])
        self.window3 = window(window_size=window_size[2], dim=dims[2])
        # Stride-2 convs halve each IR-50 feature map before windowing.
        self.conv1 = nn.Conv2d(
            in_channels=dims[0],
            out_channels=dims[0],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=dims[1],
            out_channels=dims[1],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=dims[2],
            out_channels=dims[2],
            kernel_size=3,
            stride=2,
            padding=1,
        )

        # Linearly increasing DropPath rates; only the first three are used.
        dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
        self.ffn1 = feedforward(
            dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
        )
        self.ffn2 = feedforward(
            dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
        )
        self.ffn3 = feedforward(
            dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
        )

        # Maps the last MobileFaceNet stage to 256 channels; this must equal
        # dims[2] for _to_query(q3) to reshape correctly.
        self.last_face_conv = nn.Conv2d(
            in_channels=512, out_channels=256, kernel_size=3, padding=1
        )

        # Token embeddings for the three scales use the same width that the
        # VisionTransformer is built with (768 by default).
        # NOTE(review): confirm PatchEmbed honors embed_dim for non-default values.
        self.embed_q = nn.Sequential(
            nn.Conv2d(dims[0], embed_dim, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1),
        )
        self.embed_k = nn.Sequential(
            nn.Conv2d(dims[1], embed_dim, kernel_size=3, stride=2, padding=1)
        )
        self.embed_v = PatchEmbed(
            img_size=14, patch_size=14, in_c=256, embed_dim=embed_dim
        )

    def forward(self, x):
        """Run both backbones, cross-attend at three scales, then the ViT.

        Args:
            x: input image batch (B, 3, H, W); presumably img_size x img_size.

        Returns:
            Whatever ``self.VIT`` returns for the concatenated token sequence.
        """
        # Landmark branch works at 112x112.
        x_face = F.interpolate(x, size=112)
        x_face1, x_face2, x_face3 = self.face_landback(x_face)
        x_face3 = self.last_face_conv(x_face3)
        x_face1, x_face2, x_face3 = (
            _to_channel_last(x_face1),
            _to_channel_last(x_face2),
            _to_channel_last(x_face3),
        )

        q1, q2, q3 = (
            _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
            _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
            _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
        )

        x_ir1, x_ir2, x_ir3 = self.ir_back(x)
        x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
        x_window1, shortcut1 = self.window1(x_ir1)
        x_window2, shortcut2 = self.window2(x_ir2)
        x_window3, shortcut3 = self.window3(x_ir3)

        o1, o2, o3 = (
            self.attn1(x_window1, q1),
            self.attn2(x_window2, q2),
            self.attn3(x_window3, q3),
        )

        o1, o2, o3 = (
            self.ffn1(o1, shortcut1),
            self.ffn2(o2, shortcut2),
            self.ffn3(o3, shortcut3),
        )

        o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)

        # Flatten each scale to a (B, tokens, embed_dim) sequence.
        o1, o2, o3 = (
            self.embed_q(o1).flatten(2).transpose(1, 2),
            self.embed_k(o2).flatten(2).transpose(1, 2),
            self.embed_v(o3),
        )

        o = torch.cat([o1, o2, o3], dim=1)

        out = self.VIT(o)
        return out


def compute_param_flop():
    """Print thop's FLOP and parameter counts for one 224x224 input."""
    model = pyramid_trans_expr2()
    img = torch.rand(size=(1, 3, 224, 224))
    flops, params = profile(model, inputs=(img,))
    print(f"flops:{flops/1000**3}G,params:{params/1000**2}M")