# Source: TADBot / FER / models / PosterV2_7cls.py
# Uploaded by ryefoxlime -- commit "FER alpha 0.1" (499f0dc), 14.6 kB
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from .mobilefacenet import MobileFaceNet
from .ir50 import Backbone
from .vit_model import VisionTransformer, PatchEmbed
from timm.layers import trunc_normal_, DropPath
from thop import profile
def load_pretrained_weights(model, checkpoint):
    """Copy every compatible tensor from ``checkpoint`` into ``model``.

    Args:
        model: module whose ``state_dict`` is partially overwritten.
        checkpoint: mapping of parameter names to tensors, or a mapping
            that holds such a mapping under the ``"state_dict"`` key.

    Returns:
        The same ``model`` object, after loading.

    A tensor is copied only when its name (with a leading ``"module."``
    stripped) exists in the model and both sizes agree; everything else is
    left at the model's current value. The number of copied tensors is
    printed.
    """
    import collections

    source_weights = (
        checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    )
    target_weights = model.state_dict()

    accepted = collections.OrderedDict()
    matched_names = []
    skipped_names = []
    for raw_name, tensor in source_weights.items():
        # Checkpoints saved from nn.DataParallel prefix every key with
        # "module."; drop it so the names line up with a bare model.
        name = raw_name
        if name.startswith("module."):
            name = name[len("module.") :]

        compatible = (
            name in target_weights and target_weights[name].size() == tensor.size()
        )
        if not compatible:
            skipped_names.append(name)
            continue
        accepted[name] = tensor
        matched_names.append(name)

    target_weights.update(accepted)
    model.load_state_dict(target_weights)
    print("load_weight", len(matched_names))
    return model
def window_partition(x, window_size, h_w, w_w):
    """
    Cut a channel-last feature map into square, non-overlapping windows.

    Args:
        x: (B, H, W, C)
        window_size: side length of each window
        h_w: number of windows along the height axis
        w_w: number of windows along the width axis
    Returns:
        local window features (num_windows*B, window_size, window_size, C)
    """
    batch, _height, _width, channels = x.shape
    # Split H into (h_w, window_size) and W into (w_w, window_size) ...
    grid = x.view(batch, h_w, window_size, w_w, window_size, channels)
    # ... then bring the two window-index axes next to the batch axis.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
class window(nn.Module):
    """Layer-normalise a feature map and split it into flattened windows.

    Args:
        window_size: side length of the square windows.
        dim: channel count, used as the LayerNorm feature size.
    """

    def __init__(self, window_size, dim):
        super().__init__()
        self.window_size = window_size
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W)
        Returns:
            x_windows: (num_windows*B, window_size*window_size, C)
            shortcut: the normalised map in channel-last layout (B, H, W, C)
        """
        x = x.permute(0, 2, 3, 1)
        _, H, W, C = x.shape
        x = self.norm(x)
        # The residual branch keeps the *normalised* map, not the raw input.
        shortcut = x
        # Plain integer division: H and W are Python ints, so there is no
        # need to round-trip through a float tensor to count the windows.
        h_w = H // self.window_size
        w_w = W // self.window_size
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows, shortcut
class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"

    Keys and values are projected from the local window tokens, while the
    queries are supplied by the caller (``q_global``). A learned relative
    position bias is added to the attention logits before the softmax.
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            window_size: window size (side length of the square window).
            qkv_bias: bool argument for key, value learnable bias.
            qk_scale: optional number overriding the default query scale
                ``head_dim ** -0.5``; a falsy value selects the default.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """
        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        # Keep the scale a plain Python float: a tensor stored as an ordinary
        # attribute is neither a parameter nor a buffer, so it would not
        # follow the module across devices or dtypes.
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # One learnable bias per (relative offset, head).
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # Pre-compute, for every pair of positions inside a window, the row
        # of the bias table that holds their relative offset.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # "ij" is the historical default; naming it explicitly silences the
        # deprecation warning without changing the result.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to start at 0, then fold (dy, dx) into a single index.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        # Despite its name this layer outputs 2*dim features: keys and values
        # only. The attribute name is kept so existing checkpoints still load.
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        """
        Args:
            x: window tokens of shape (B_, N, C).
            q_global: query tensor whose first dimension is the batch size B;
                it is tiled B_ // B times along dim 1 and reshaped to
                (B_, num_heads, N, C // num_heads).
        Returns:
            attended window tokens of shape (B_, N, C).
        """
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = C // self.num_heads
        # Number of windows per image.
        B_dim = B_ // B
        kv = (
            self.qkv(x)
            .reshape(B_, N, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        # Share the same per-image query with every window of that image.
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def _to_channel_last(x):
    """
    Move the channel axis of a 4-D tensor to the end.

    Args:
        x: (B, C, H, W)
    Returns:
        x: (B, H, W, C)
    """
    channel_last_order = (0, 2, 3, 1)
    return x.permute(*channel_last_order)
def _to_channel_first(x):
    """
    Move the trailing channel axis of a 4-D tensor to position 1.

    Args:
        x: (B, H, W, C)
    Returns:
        x: (B, C, H, W)
    """
    channel_first_order = (0, 3, 1, 2)
    return x.permute(*channel_first_order)
def _to_query(x, N, num_heads, dim_head):
    """
    Reshape a tensor into a per-head query layout.

    Args:
        x: tensor whose first dimension is the batch size B and which holds
            N * num_heads * dim_head elements per batch entry.
        N: number of tokens.
        num_heads: number of attention heads.
        dim_head: feature size per head.
    Returns:
        tensor of shape (B, 1, num_heads, N, dim_head)
    """
    batch = x.shape[0]
    per_token = x.reshape(batch, 1, N, num_heads, dim_head)
    # Swap the token and head axes so heads come first.
    return per_token.transpose(2, 3)
class Mlp(nn.Module):
    """
    Two-layer feed-forward block: linear, activation, dropout, linear, dropout.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: size of the input features.
            hidden_features: size of the hidden layer; falls back to
                ``in_features`` when falsy.
            out_features: size of the output; falls back to ``in_features``
                when falsy.
            act_layer: callable that builds the activation module.
            drop: dropout probability, shared by both dropout applications.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_reverse(windows, window_size, H, W, h_w, w_w):
    """
    Stitch windows back into a channel-last feature map.

    Args:
        windows: local window features whose first dimension is
            num_windows*B, e.g. (num_windows*B, window_size, window_size, C)
            or the flattened (num_windows*B, window_size*window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image
        h_w: number of windows along the height axis
        w_w: number of windows along the width axis
    Returns:
        x: (B, H, W, C)
    """
    # Exact integer arithmetic: each image contributes h_w * w_w windows.
    # (Deriving this from H * W / window_size / window_size goes through
    # floating point for what is an integer count.)
    B = windows.shape[0] // (h_w * w_w)
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    # Interleave window-grid and in-window axes back into (H, W).
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class feedforward(nn.Module):
    """Merge attended windows back into a map and apply a residual MLP.

    Args:
        dim: channel count of the feature map.
        window_size: side length of the square windows being merged.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        act_layer: callable that builds the MLP activation.
        drop: dropout rate inside the MLP.
        drop_path: stochastic-depth rate; 0 disables it.
        layer_scale: if a number, initial value of the learnable per-channel
            scales ``gamma1``/``gamma2``; otherwise both are the constant 1.0.
    """

    def __init__(
        self,
        dim,
        window_size,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        drop=0.0,
        drop_path=0.0,
        layer_scale=None,
    ):
        super().__init__()
        # bool is excluded on purpose: it is an int subclass, but the
        # original exact-type check never treated True/False as a scale.
        use_layer_scale = isinstance(layer_scale, (int, float)) and not isinstance(
            layer_scale, bool
        )
        # Always define the flag so it can be read in either configuration.
        self.layer_scale = use_layer_scale
        if use_layer_scale:
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.window_size = window_size
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, attn_windows, shortcut):
        """
        Args:
            attn_windows: attended window tokens, first dimension
                num_windows*B.
            shortcut: residual map of shape (B, H, W, C).
        Returns:
            feature map of shape (B, H, W, C).
        """
        _, H, W, _ = shortcut.shape
        # H and W are Python ints; integer division gives the window grid.
        h_w = H // self.window_size
        w_w = W // self.window_size
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
        return x
class pyramid_trans_expr2(nn.Module):
    """Three-scale model that fuses two backbones with window attention.

    A frozen ``MobileFaceNet`` supplies the queries and an IR-50
    ``Backbone`` supplies the window tokens (keys/values) at three scales.
    The three attended maps are embedded to ``embed_dim``, concatenated
    along the token axis and passed to a 2-layer ``VisionTransformer``.

    Args:
        img_size: stored on the instance; not otherwise used in this class.
        num_classes: stored on the instance.
            NOTE(review): it is not forwarded to ``VisionTransformer`` here,
            so the output width comes from that class's own default --
            confirm this is intended.
        window_size: window side length for each of the three scales.
        num_heads: attention head count for each of the three scales.
        dims: channel count for each of the three scales.
        embed_dim: token width fed to the ``VisionTransformer``.

    Construction reads two checkpoint files from the ``pretrain`` directory
    next to this source file: ``mobilefacenet_model_best.pth.tar`` and
    ``ir50.pth``.
    """

    def __init__(
        self,
        img_size=224,
        num_classes=7,
        window_size=[28, 14, 7],
        num_heads=[2, 4, 8],
        dims=[64, 128, 256],
        embed_dim=768,
    ):
        super().__init__()
        # Store copies so the shared default lists can never be mutated
        # through an instance attribute.
        window_size = list(window_size)
        num_heads = list(num_heads)
        dims = list(dims)

        self.img_size = img_size
        self.num_heads = num_heads
        self.dim_head = [dim // num_head for num_head, dim in zip(num_heads, dims)]
        self.num_classes = num_classes
        self.window_size = window_size
        # Tokens per window at each scale.
        self.N = [win * win for win in window_size]

        self.face_landback = MobileFaceNet([112, 112], 136)

        # Checkpoints live in "<directory of this file>/pretrain/".
        script_dir = os.path.dirname(os.path.abspath(__file__))
        mobilefacenet_path = os.path.join(
            script_dir, "pretrain", "mobilefacenet_model_best.pth.tar"
        )
        # Build the path from components: a literal backslash is only a
        # separator on Windows and would be part of the file name elsewhere.
        ir50_path = os.path.join(script_dir, "pretrain", "ir50.pth")
        print(mobilefacenet_path)

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects.
        # Only load checkpoint files that come from a trusted source.
        face_landback_checkpoint = torch.load(
            mobilefacenet_path,
            map_location=lambda storage, loc: storage,
            weights_only=False,
        )
        self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])

        # The landmark backbone is used as a fixed feature extractor.
        for param in self.face_landback.parameters():
            param.requires_grad = False

        self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim)

        self.ir_back = Backbone(50, 0.0, "ir")
        # SECURITY NOTE: same trusted-source requirement as above.
        ir_checkpoint = torch.load(
            ir50_path, map_location=lambda storage, loc: storage, weights_only=False
        )
        self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        self.attn1 = WindowAttentionGlobal(
            dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
        )
        self.attn2 = WindowAttentionGlobal(
            dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
        )
        self.attn3 = WindowAttentionGlobal(
            dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
        )
        self.window1 = window(window_size=window_size[0], dim=dims[0])
        self.window2 = window(window_size=window_size[1], dim=dims[1])
        self.window3 = window(window_size=window_size[2], dim=dims[2])
        # Stride-2 convolutions applied to the IR-50 features before windowing.
        self.conv1 = nn.Conv2d(
            in_channels=dims[0],
            out_channels=dims[0],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=dims[1],
            out_channels=dims[1],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=dims[2],
            out_channels=dims[2],
            kernel_size=3,
            stride=2,
            padding=1,
        )

        # Five evenly spaced drop-path rates in [0, 0.5]; only the first
        # three are consumed, one per scale.
        dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
        self.ffn1 = feedforward(
            dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
        )
        self.ffn2 = feedforward(
            dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
        )
        self.ffn3 = feedforward(
            dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
        )

        self.last_face_conv = nn.Conv2d(
            in_channels=512, out_channels=256, kernel_size=3, padding=1
        )

        # Project every scale to the transformer width. These follow
        # ``embed_dim`` so the tokens always match what ``self.VIT`` expects.
        self.embed_q = nn.Sequential(
            nn.Conv2d(dims[0], embed_dim, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1),
        )
        self.embed_k = nn.Sequential(
            nn.Conv2d(dims[1], embed_dim, kernel_size=3, stride=2, padding=1)
        )
        self.embed_v = PatchEmbed(
            img_size=14, patch_size=14, in_c=256, embed_dim=embed_dim
        )

    def forward(self, x):
        """
        Args:
            x: image batch; it is fed to the IR-50 backbone as is and resized
                to 112 for the landmark backbone.
        Returns:
            the output of ``self.VIT`` on the concatenated multi-scale tokens.
        """
        # Landmark branch: produces the attention queries.
        x_face = F.interpolate(x, size=112)
        x_face1, x_face2, x_face3 = self.face_landback(x_face)
        x_face3 = self.last_face_conv(x_face3)
        x_face1, x_face2, x_face3 = (
            _to_channel_last(x_face1),
            _to_channel_last(x_face2),
            _to_channel_last(x_face3),
        )
        q1, q2, q3 = (
            _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
            _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
            _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
        )

        # Image branch: produces the window tokens (keys and values).
        x_ir1, x_ir2, x_ir3 = self.ir_back(x)
        x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
        x_window1, shortcut1 = self.window1(x_ir1)
        x_window2, shortcut2 = self.window2(x_ir2)
        x_window3, shortcut3 = self.window3(x_ir3)

        o1, o2, o3 = (
            self.attn1(x_window1, q1),
            self.attn2(x_window2, q2),
            self.attn3(x_window3, q3),
        )
        o1, o2, o3 = (
            self.ffn1(o1, shortcut1),
            self.ffn2(o2, shortcut2),
            self.ffn3(o3, shortcut3),
        )
        o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)

        # Turn each scale into a token sequence and join them on the token axis.
        o1, o2, o3 = (
            self.embed_q(o1).flatten(2).transpose(1, 2),
            self.embed_k(o2).flatten(2).transpose(1, 2),
            self.embed_v(o3),
        )
        o = torch.cat([o1, o2, o3], dim=1)
        out = self.VIT(o)
        return out
def compute_param_flop():
    """Profile a default ``pyramid_trans_expr2`` on one random input.

    Builds the model with its default arguments, runs ``profile`` on a
    single random (1, 3, 224, 224) tensor and prints the two returned
    counts, scaled by 1000**3 and 1000**2 respectively.
    """
    net = pyramid_trans_expr2()
    dummy_input = torch.rand(size=(1, 3, 224, 224))
    flops, params = profile(net, inputs=(dummy_input,))
    giga_flops = flops / 1000**3
    mega_params = params / 1000**2
    print(f"flops:{giga_flops}G,params:{mega_params}M")