ryefoxlime committed
Commit 499f0dc
1 Parent(s): 49cca10

FER alpha 0.1
.gitignore CHANGED
@@ -1,4 +1,5 @@
  FER/models/__pycache__
  FER/__pycache__
  .env
- .venv
+ .venv
+ FER/Images/
FER/data_preprocessing/__pycache__/sam.cpython-311.pyc ADDED
Binary file (4.7 kB).
 
FER/data_preprocessing/sam.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+
+
+ class SAM(torch.optim.Optimizer):
+     def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
+         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+
+         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
+         super(SAM, self).__init__(params, defaults)
+
+         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+         self.param_groups = self.base_optimizer.param_groups
+
+     @torch.no_grad()
+     def first_step(self, zero_grad=False):
+         grad_norm = self._grad_norm()
+         for group in self.param_groups:
+             scale = group["rho"] / (grad_norm + 1e-12)
+
+             for p in group["params"]:
+                 if p.grad is None: continue
+                 self.state[p]["old_p"] = p.data.clone()
+                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
+                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
+
+         if zero_grad: self.zero_grad()
+
+     @torch.no_grad()
+     def second_step(self, zero_grad=False):
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None: continue
+                 p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
+
+         self.base_optimizer.step()  # do the actual "sharpness-aware" update
+
+         if zero_grad: self.zero_grad()
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
+         closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
+
+         self.first_step(zero_grad=True)
+         closure()
+         self.second_step()
+
+     def _grad_norm(self):
+         shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
+         norm = torch.norm(
+             torch.stack([
+                 ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
+                 for group in self.param_groups for p in group["params"]
+                 if p.grad is not None
+             ]),
+             p=2
+         )
+         return norm
+
+     def load_state_dict(self, state_dict):
+         super().load_state_dict(state_dict)
+         self.base_optimizer.param_groups = self.param_groups
+
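Note: the SAM wrapper added above follows the usual two-pass pattern: first_step() climbs to the perturbed weights "w + e(w)", a second forward/backward pass computes gradients there, and second_step() restores "w" and applies the base optimizer. A minimal training-loop sketch of how this class is typically driven (model, criterion and train_loader are assumed to exist elsewhere, and the import path is inferred from the file location; none of this is part of the commit):

import torch
from FER.data_preprocessing.sam import SAM

base_optimizer = torch.optim.SGD            # pass the optimizer class, not an instance
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)

for images, labels in train_loader:
    # first forward/backward pass: gradients at w, then move to w + e(w)
    criterion(model(images), labels).backward()
    optimizer.first_step(zero_grad=True)

    # second forward/backward pass: gradients at w + e(w), then update w
    criterion(model(images), labels).backward()
    optimizer.second_step(zero_grad=True)
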
FER/detectfaces.py CHANGED
@@ -4,13 +4,10 @@ import torch
4
  import os
5
  import time
6
  from PIL import Image
 
7
 
8
  # Define the path to the model checkpoint
9
- # Get the directory of the current file (models/PosterV2_7cls.py)
10
- script_dir = os.path.dirname(os.path.abspath(__file__))
11
-
12
- # Construct the full path to the model file
13
- model_path = os.path.join(script_dir, r"models\checkpoints\raf-db-model_best.pth")
14
 
15
  # Determine the available device for model execution
16
  if torch.backends.mps.is_available():
@@ -39,7 +36,7 @@ def main():
39
  if model_path is not None:
40
  if os.path.isfile(model_path):
41
  print("=> loading checkpoint '{}'".format(model_path))
42
- checkpoint = torch.load(model_path, map_location=device)
43
  best_acc = checkpoint["best_acc"]
44
  best_acc = best_acc.to()
45
  print(f"best_acc:{best_acc}")
@@ -50,7 +47,9 @@ def main():
50
  )
51
  )
52
  else:
53
- print("=> no checkpoint found at '{}'".format(model_path))
 
 
54
  # Start webcam capture and prediction
55
  imagecapture(model)
56
  return
@@ -92,7 +91,7 @@ def imagecapture(model):
92
 
93
  # If faces are detected, proceed with prediction
94
  if len(faces) > 0:
95
- currtimeimg = time.strftime("%H:%M:%S:%.*f")
96
  print(f"[!]Face detected at {currtimeimg}")
97
  # Crop the face region
98
  face_region = frame[
@@ -105,12 +104,12 @@ def imagecapture(model):
105
  )
106
  print("[!]Start Expressions")
107
  # Record the prediction start time
108
- starttime = time.strftime("%H:%M:%S:%.*f")
109
  print(f"-->Prediction starting at {starttime}")
110
  # Perform emotion prediction
111
  predict(model, image_path=face_pil_image)
112
  # Record the prediction end time
113
- endtime = time.strftime("%H:%M:%S:%.*f")
114
  print(f"-->Done prediction at {endtime}")
115
 
116
  # Stop capturing once prediction is complete
 
4
  import os
5
  import time
6
  from PIL import Image
7
+ from main import RecorderMeter1, RecorderMeter
8
 
9
  # Define the path to the model checkpoint
10
+ model_path = os.path.abspath(r"FER\models\checkpoints\raf-db-model_best.pth")
 
 
 
 
11
 
12
  # Determine the available device for model execution
13
  if torch.backends.mps.is_available():
 
36
  if model_path is not None:
37
  if os.path.isfile(model_path):
38
  print("=> loading checkpoint '{}'".format(model_path))
39
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
40
  best_acc = checkpoint["best_acc"]
41
  best_acc = best_acc.to()
42
  print(f"best_acc:{best_acc}")
 
47
  )
48
  )
49
  else:
50
+ print(
51
+ "[!] detectfaces.py => no checkpoint found at '{}'".format(model_path)
52
+ )
53
  # Start webcam capture and prediction
54
  imagecapture(model)
55
  return
 
91
 
92
  # If faces are detected, proceed with prediction
93
  if len(faces) > 0:
94
+ currtimeimg = time.strftime("%H:%M:%S")
95
  print(f"[!]Face detected at {currtimeimg}")
96
  # Crop the face region
97
  face_region = frame[
 
104
  )
105
  print("[!]Start Expressions")
106
  # Record the prediction start time
107
+ starttime = time.strftime("%H:%M:%S")
108
  print(f"-->Prediction starting at {starttime}")
109
  # Perform emotion prediction
110
  predict(model, image_path=face_pil_image)
111
  # Record the prediction end time
112
+ endtime = time.strftime("%H:%M:%S")
113
  print(f"-->Done prediction at {endtime}")
114
 
115
  # Stop capturing once prediction is complete
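
Note on the torch.load change above: recent PyTorch releases default weights_only to True, which rejects checkpoints that pickle arbitrary Python objects (here the RecorderMeter/RecorderMeter1 classes imported from main). Passing weights_only=False restores the old behaviour but executes pickled code on load; an alternative sketch for PyTorch >= 2.4, not used in this commit, is to allow-list the classes and keep the safer default (model_path and device mirror the module-level names above):

import torch
from main import RecorderMeter, RecorderMeter1

# tell the safe unpickler that these classes are expected inside the checkpoint
torch.serialization.add_safe_globals([RecorderMeter, RecorderMeter1])
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
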
FER/main.py CHANGED
@@ -21,7 +21,7 @@ import torchvision.transforms as transforms
21
  import numpy as np
22
  import datetime
23
  from torchsampler import ImbalancedDatasetSampler
24
- from models.PosterV2_7cls import *
25
 
26
 
27
  warnings.filterwarnings("ignore", category=UserWarning)
 
21
  import numpy as np
22
  import datetime
23
  from torchsampler import ImbalancedDatasetSampler
24
+ from models.PosterV2_7cls import pyramid_trans_expr2
25
 
26
 
27
  warnings.filterwarnings("ignore", category=UserWarning)
FER/models/PosterV2_7cls.py CHANGED
@@ -5,7 +5,7 @@ from torch.nn import functional as F
5
  from .mobilefacenet import MobileFaceNet
6
  from .ir50 import Backbone
7
  from .vit_model import VisionTransformer, PatchEmbed
8
- from timm.models.layers import trunc_normal_, DropPath
9
  from thop import profile
10
 
11
 
@@ -315,6 +315,7 @@ class pyramid_trans_expr2(nn.Module):
315
  face_landback_checkpoint = torch.load(
316
  mobilefacenet_path,
317
  map_location=lambda storage, loc: storage,
 
318
  )
319
  self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])
320
 
@@ -325,8 +326,7 @@ class pyramid_trans_expr2(nn.Module):
325
 
326
  self.ir_back = Backbone(50, 0.0, "ir")
327
  ir_checkpoint = torch.load(
328
- ir50_path,
329
- map_location=lambda storage, loc: storage,
330
  )
331
 
332
  self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
 
5
  from .mobilefacenet import MobileFaceNet
6
  from .ir50 import Backbone
7
  from .vit_model import VisionTransformer, PatchEmbed
8
+ from timm.layers import trunc_normal_, DropPath
9
  from thop import profile
10
 
11
 
 
315
  face_landback_checkpoint = torch.load(
316
  mobilefacenet_path,
317
  map_location=lambda storage, loc: storage,
318
+ weights_only=False,
319
  )
320
  self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])
321
 
 
326
 
327
  self.ir_back = Backbone(50, 0.0, "ir")
328
  ir_checkpoint = torch.load(
329
+ ir50_path, map_location=lambda storage, loc: storage, weights_only=False
 
330
  )
331
 
332
  self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
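
The import move from timm.models.layers to timm.layers above tracks timm's package reorganisation; the old path has been deprecated for a while and is dropped in recent releases. If the code ever needs to run against both old and new timm versions, a small compatibility shim (a sketch, not part of this commit) would be:

try:
    from timm.layers import trunc_normal_, DropPath          # timm >= 0.9
except ImportError:
    from timm.models.layers import trunc_normal_, DropPath   # older timm releases
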
FER/models/PosterV2_8cls.py CHANGED
@@ -4,13 +4,15 @@ from torch.nn import functional as F
4
  from .mobilefacenet import MobileFaceNet
5
  from .ir50 import Backbone
6
  from .vit_model_8 import VisionTransformer, PatchEmbed
7
- from timm.models.layers import trunc_normal_, DropPath
8
  from thop import profile
9
 
 
10
  def load_pretrained_weights(model, checkpoint):
11
  import collections
12
- if 'state_dict' in checkpoint:
13
- state_dict = checkpoint['state_dict']
 
14
  else:
15
  state_dict = checkpoint
16
  model_dict = model.state_dict()
@@ -19,7 +21,7 @@ def load_pretrained_weights(model, checkpoint):
19
  for k, v in state_dict.items():
20
  # If the pretrained state_dict was saved as nn.DataParallel,
21
  # keys would contain "module.", which should be ignored.
22
- if k.startswith('module.'):
23
  k = k[7:]
24
  if k in model_dict and model_dict[k].size() == v.size():
25
  new_state_dict[k] = v
@@ -30,9 +32,10 @@ def load_pretrained_weights(model, checkpoint):
30
  model_dict.update(new_state_dict)
31
 
32
  model.load_state_dict(model_dict)
33
- print('load_weight', len(matched_layers))
34
  return model
35
 
 
36
  def window_partition(x, window_size, h_w, w_w):
37
  """
38
  Args:
@@ -44,14 +47,18 @@ def window_partition(x, window_size, h_w, w_w):
44
  """
45
  B, H, W, C = x.shape
46
  x = x.view(B, h_w, window_size, w_w, window_size, C)
47
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
 
 
48
  return windows
49
 
 
50
  class window(nn.Module):
51
  def __init__(self, window_size, dim):
52
  super(window, self).__init__()
53
  self.window_size = window_size
54
  self.norm = nn.LayerNorm(dim)
 
55
  def forward(self, x):
56
  x = x.permute(0, 2, 3, 1)
57
  B, H, W, C = x.shape
@@ -63,21 +70,23 @@ class window(nn.Module):
63
  x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
64
  return x_windows, shortcut
65
 
 
66
  class WindowAttentionGlobal(nn.Module):
67
  """
68
  Global window attention based on: "Hatamizadeh et al.,
69
  Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
70
  """
71
 
72
- def __init__(self,
73
- dim,
74
- num_heads,
75
- window_size,
76
- qkv_bias=True,
77
- qk_scale=None,
78
- attn_drop=0.,
79
- proj_drop=0.,
80
- ):
 
81
  """
82
  Args:
83
  dim: feature size dimension.
@@ -94,9 +103,10 @@ class WindowAttentionGlobal(nn.Module):
94
  self.window_size = window_size
95
  self.num_heads = num_heads
96
  head_dim = torch.div(dim, num_heads)
97
- self.scale = qk_scale or head_dim ** -0.5
98
  self.relative_position_bias_table = nn.Parameter(
99
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
 
100
  coords_h = torch.arange(self.window_size[0])
101
  coords_w = torch.arange(self.window_size[1])
102
  coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
@@ -112,7 +122,7 @@ class WindowAttentionGlobal(nn.Module):
112
  self.attn_drop = nn.Dropout(attn_drop)
113
  self.proj = nn.Linear(dim, dim)
114
  self.proj_drop = nn.Dropout(proj_drop)
115
- trunc_normal_(self.relative_position_bias_table, std=.02)
116
  self.softmax = nn.Softmax(dim=-1)
117
 
118
  def forward(self, x, q_global):
@@ -122,14 +132,23 @@ class WindowAttentionGlobal(nn.Module):
122
  B = q_global.shape[0]
123
  head_dim = int(torch.div(C, self.num_heads).item())
124
  B_dim = int(torch.div(B_, B).item())
125
- kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
 
 
 
 
126
  k, v = kv[0], kv[1]
127
  q_global = q_global.repeat(1, B_dim, 1, 1, 1)
128
  q = q_global.reshape(B_, self.num_heads, N, head_dim)
129
  q = q * self.scale
130
- attn = (q @ k.transpose(-2, -1))
131
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
132
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
 
 
 
 
 
133
  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
134
  attn = attn + relative_position_bias.unsqueeze(0)
135
  attn = self.softmax(attn)
@@ -139,6 +158,7 @@ class WindowAttentionGlobal(nn.Module):
139
  x = self.proj_drop(x)
140
  return x
141
 
 
142
  def _to_channel_last(x):
143
  """
144
  Args:
@@ -149,25 +169,30 @@ def _to_channel_last(x):
149
  """
150
  return x.permute(0, 2, 3, 1)
151
 
 
152
  def _to_channel_first(x):
153
  return x.permute(0, 3, 1, 2)
154
 
 
155
  def _to_query(x, N, num_heads, dim_head):
156
  B = x.shape[0]
157
  x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
158
  return x
159
 
 
160
  class Mlp(nn.Module):
161
  """
162
  Multi-Layer Perceptron (MLP) block
163
  """
164
 
165
- def __init__(self,
166
- in_features,
167
- hidden_features=None,
168
- out_features=None,
169
- act_layer=nn.GELU,
170
- drop=0.):
 
 
171
  """
172
  Args:
173
  in_features: input features dimension.
@@ -193,6 +218,7 @@ class Mlp(nn.Module):
193
  x = self.drop(x)
194
  return x
195
 
 
196
  def window_reverse(windows, window_size, H, W, h_w, w_w):
197
  """
198
  Args:
@@ -209,20 +235,40 @@ def window_reverse(windows, window_size, H, W, h_w, w_w):
209
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
210
  return x
211
 
 
212
  class feedforward(nn.Module):
213
- def __init__(self, dim, window_size, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0., layer_scale=None):
 
 
 
 
 
 
 
 
 
214
  super(feedforward, self).__init__()
215
  if layer_scale is not None and type(layer_scale) in [int, float]:
216
  self.layer_scale = True
217
- self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
218
- self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
 
 
 
 
219
  else:
220
  self.gamma1 = 1.0
221
  self.gamma2 = 1.0
222
  self.window_size = window_size
223
- self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
 
 
 
 
 
224
  self.norm = nn.LayerNorm(dim)
225
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
226
  def forward(self, attn_windows, shortcut):
227
  B, H, W, C = shortcut.shape
228
  h_w = int(torch.div(H, self.window_size).item())
@@ -232,8 +278,17 @@ class feedforward(nn.Module):
232
  x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
233
  return x
234
 
 
235
  class pyramid_trans_expr2(nn.Module):
236
- def __init__(self, img_size=224, num_classes=8, window_size=[28,14,7], num_heads=[2, 4, 8], dims=[64, 128, 256], embed_dim=768):
 
 
 
 
 
 
 
 
237
  super().__init__()
238
 
239
  self.img_size = img_size
@@ -245,51 +300,99 @@ class pyramid_trans_expr2(nn.Module):
245
  self.window_size = window_size
246
  self.N = [win * win for win in window_size]
247
  self.face_landback = MobileFaceNet([112, 112], 136)
248
- face_landback_checkpoint = torch.load(r'./pretrain/mobilefacenet_model_best.pth.tar',
249
- map_location=lambda storage, loc: storage)
250
- self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
 
 
251
 
252
  for param in self.face_landback.parameters():
253
  param.requires_grad = False
254
 
255
- self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim, num_classes=num_classes)
 
 
256
 
257
- self.ir_back = Backbone(50, 0.0, 'ir')
258
- ir_checkpoint = torch.load(r'./pretrain/ir50.pth', map_location=lambda storage, loc: storage)
 
 
259
 
260
  self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
261
 
262
- self.attn1 = WindowAttentionGlobal(dim=dims[0], num_heads=num_heads[0], window_size=window_size[0])
263
- self.attn2 = WindowAttentionGlobal(dim=dims[1], num_heads=num_heads[1], window_size=window_size[1])
264
- self.attn3 = WindowAttentionGlobal(dim=dims[2], num_heads=num_heads[2], window_size=window_size[2])
 
 
 
 
 
 
265
  self.window1 = window(window_size=window_size[0], dim=dims[0])
266
  self.window2 = window(window_size=window_size[1], dim=dims[1])
267
  self.window3 = window(window_size=window_size[2], dim=dims[2])
268
- self.conv1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[0], kernel_size=3, stride=2, padding=1)
269
- self.conv2 = nn.Conv2d(in_channels=dims[1], out_channels=dims[1], kernel_size=3, stride=2, padding=1)
270
- self.conv3 = nn.Conv2d(in_channels=dims[2], out_channels=dims[2], kernel_size=3, stride=2, padding=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
273
- self.ffn1 = feedforward(dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0])
274
- self.ffn2 = feedforward(dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1])
275
- self.ffn3 = feedforward(dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2])
276
-
277
- self.last_face_conv = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
278
-
279
- self.embed_q = nn.Sequential(nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
280
- nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1))
281
- self.embed_k = nn.Sequential(nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1))
 
 
 
 
 
 
 
 
 
 
 
 
282
  self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
283
 
284
  def forward(self, x):
285
  x_face = F.interpolate(x, size=112)
286
- x_face1 , x_face2, x_face3 = self.face_landback(x_face)
287
  x_face3 = self.last_face_conv(x_face3)
288
- x_face1, x_face2, x_face3 = _to_channel_last(x_face1), _to_channel_last(x_face2), _to_channel_last(x_face3)
289
-
290
- q1, q2, q3 = _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]), \
291
- _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]), \
292
- _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2])
 
 
 
 
 
 
293
 
294
  x_ir1, x_ir2, x_ir3 = self.ir_back(x)
295
  x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
@@ -297,21 +400,34 @@ class pyramid_trans_expr2(nn.Module):
297
  x_window2, shortcut2 = self.window2(x_ir2)
298
  x_window3, shortcut3 = self.window3(x_ir3)
299
 
300
- o1, o2, o3 = self.attn1(x_window1, q1), self.attn2(x_window2, q2), self.attn3(x_window3, q3)
 
 
 
 
301
 
302
- o1, o2, o3 = self.ffn1(o1, shortcut1), self.ffn2(o2, shortcut2), self.ffn3(o3, shortcut3)
 
 
 
 
303
 
304
  o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)
305
 
306
- o1, o2, o3 = self.embed_q(o1).flatten(2).transpose(1, 2), self.embed_k(o2).flatten(2).transpose(1, 2), self.embed_v(o3)
 
 
 
 
307
 
308
  o = torch.cat([o1, o2, o3], dim=1)
309
 
310
  out = self.VIT(o)
311
  return out
312
 
 
313
  def compute_param_flop():
314
  model = pyramid_trans_expr2()
315
- img = torch.rand(size=(1,3,224,224))
316
  flops, params = profile(model, inputs=(img,))
317
- print(f'flops:{flops/1000**3}G,params:{params/1000**2}M')
 
4
  from .mobilefacenet import MobileFaceNet
5
  from .ir50 import Backbone
6
  from .vit_model_8 import VisionTransformer, PatchEmbed
7
+ from timm.layers import trunc_normal_, DropPath
8
  from thop import profile
9
 
10
+
11
  def load_pretrained_weights(model, checkpoint):
12
  import collections
13
+
14
+ if "state_dict" in checkpoint:
15
+ state_dict = checkpoint["state_dict"]
16
  else:
17
  state_dict = checkpoint
18
  model_dict = model.state_dict()
 
21
  for k, v in state_dict.items():
22
  # If the pretrained state_dict was saved as nn.DataParallel,
23
  # keys would contain "module.", which should be ignored.
24
+ if k.startswith("module."):
25
  k = k[7:]
26
  if k in model_dict and model_dict[k].size() == v.size():
27
  new_state_dict[k] = v
 
32
  model_dict.update(new_state_dict)
33
 
34
  model.load_state_dict(model_dict)
35
+ print("load_weight", len(matched_layers))
36
  return model
37
 
38
+
39
  def window_partition(x, window_size, h_w, w_w):
40
  """
41
  Args:
 
47
  """
48
  B, H, W, C = x.shape
49
  x = x.view(B, h_w, window_size, w_w, window_size, C)
50
+ windows = (
51
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
52
+ )
53
  return windows
54
 
55
+
56
  class window(nn.Module):
57
  def __init__(self, window_size, dim):
58
  super(window, self).__init__()
59
  self.window_size = window_size
60
  self.norm = nn.LayerNorm(dim)
61
+
62
  def forward(self, x):
63
  x = x.permute(0, 2, 3, 1)
64
  B, H, W, C = x.shape
 
70
  x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
71
  return x_windows, shortcut
72
 
73
+
74
  class WindowAttentionGlobal(nn.Module):
75
  """
76
  Global window attention based on: "Hatamizadeh et al.,
77
  Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
78
  """
79
 
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ num_heads,
84
+ window_size,
85
+ qkv_bias=True,
86
+ qk_scale=None,
87
+ attn_drop=0.0,
88
+ proj_drop=0.0,
89
+ ):
90
  """
91
  Args:
92
  dim: feature size dimension.
 
103
  self.window_size = window_size
104
  self.num_heads = num_heads
105
  head_dim = torch.div(dim, num_heads)
106
+ self.scale = qk_scale or head_dim**-0.5
107
  self.relative_position_bias_table = nn.Parameter(
108
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
109
+ )
110
  coords_h = torch.arange(self.window_size[0])
111
  coords_w = torch.arange(self.window_size[1])
112
  coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
 
122
  self.attn_drop = nn.Dropout(attn_drop)
123
  self.proj = nn.Linear(dim, dim)
124
  self.proj_drop = nn.Dropout(proj_drop)
125
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
126
  self.softmax = nn.Softmax(dim=-1)
127
 
128
  def forward(self, x, q_global):
 
132
  B = q_global.shape[0]
133
  head_dim = int(torch.div(C, self.num_heads).item())
134
  B_dim = int(torch.div(B_, B).item())
135
+ kv = (
136
+ self.qkv(x)
137
+ .reshape(B_, N, 2, self.num_heads, head_dim)
138
+ .permute(2, 0, 3, 1, 4)
139
+ )
140
  k, v = kv[0], kv[1]
141
  q_global = q_global.repeat(1, B_dim, 1, 1, 1)
142
  q = q_global.reshape(B_, self.num_heads, N, head_dim)
143
  q = q * self.scale
144
+ attn = q @ k.transpose(-2, -1)
145
+ relative_position_bias = self.relative_position_bias_table[
146
+ self.relative_position_index.view(-1)
147
+ ].view(
148
+ self.window_size[0] * self.window_size[1],
149
+ self.window_size[0] * self.window_size[1],
150
+ -1,
151
+ )
152
  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
153
  attn = attn + relative_position_bias.unsqueeze(0)
154
  attn = self.softmax(attn)
 
158
  x = self.proj_drop(x)
159
  return x
160
 
161
+
162
  def _to_channel_last(x):
163
  """
164
  Args:
 
169
  """
170
  return x.permute(0, 2, 3, 1)
171
 
172
+
173
  def _to_channel_first(x):
174
  return x.permute(0, 3, 1, 2)
175
 
176
+
177
  def _to_query(x, N, num_heads, dim_head):
178
  B = x.shape[0]
179
  x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
180
  return x
181
 
182
+
183
  class Mlp(nn.Module):
184
  """
185
  Multi-Layer Perceptron (MLP) block
186
  """
187
 
188
+ def __init__(
189
+ self,
190
+ in_features,
191
+ hidden_features=None,
192
+ out_features=None,
193
+ act_layer=nn.GELU,
194
+ drop=0.0,
195
+ ):
196
  """
197
  Args:
198
  in_features: input features dimension.
 
218
  x = self.drop(x)
219
  return x
220
 
221
+
222
  def window_reverse(windows, window_size, H, W, h_w, w_w):
223
  """
224
  Args:
 
235
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
236
  return x
237
 
238
+
239
  class feedforward(nn.Module):
240
+ def __init__(
241
+ self,
242
+ dim,
243
+ window_size,
244
+ mlp_ratio=4.0,
245
+ act_layer=nn.GELU,
246
+ drop=0.0,
247
+ drop_path=0.0,
248
+ layer_scale=None,
249
+ ):
250
  super(feedforward, self).__init__()
251
  if layer_scale is not None and type(layer_scale) in [int, float]:
252
  self.layer_scale = True
253
+ self.gamma1 = nn.Parameter(
254
+ layer_scale * torch.ones(dim), requires_grad=True
255
+ )
256
+ self.gamma2 = nn.Parameter(
257
+ layer_scale * torch.ones(dim), requires_grad=True
258
+ )
259
  else:
260
  self.gamma1 = 1.0
261
  self.gamma2 = 1.0
262
  self.window_size = window_size
263
+ self.mlp = Mlp(
264
+ in_features=dim,
265
+ hidden_features=int(dim * mlp_ratio),
266
+ act_layer=act_layer,
267
+ drop=drop,
268
+ )
269
  self.norm = nn.LayerNorm(dim)
270
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
271
+
272
  def forward(self, attn_windows, shortcut):
273
  B, H, W, C = shortcut.shape
274
  h_w = int(torch.div(H, self.window_size).item())
 
278
  x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
279
  return x
280
 
281
+
282
  class pyramid_trans_expr2(nn.Module):
283
+ def __init__(
284
+ self,
285
+ img_size=224,
286
+ num_classes=8,
287
+ window_size=[28, 14, 7],
288
+ num_heads=[2, 4, 8],
289
+ dims=[64, 128, 256],
290
+ embed_dim=768,
291
+ ):
292
  super().__init__()
293
 
294
  self.img_size = img_size
 
300
  self.window_size = window_size
301
  self.N = [win * win for win in window_size]
302
  self.face_landback = MobileFaceNet([112, 112], 136)
303
+ face_landback_checkpoint = torch.load(
304
+ r"./pretrain/mobilefacenet_model_best.pth.tar",
305
+ map_location=lambda storage, loc: storage,
306
+ )
307
+ self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])
308
 
309
  for param in self.face_landback.parameters():
310
  param.requires_grad = False
311
 
312
+ self.VIT = VisionTransformer(
313
+ depth=2, embed_dim=embed_dim, num_classes=num_classes
314
+ )
315
 
316
+ self.ir_back = Backbone(50, 0.0, "ir")
317
+ ir_checkpoint = torch.load(
318
+ r"./pretrain/ir50.pth", map_location=lambda storage, loc: storage
319
+ )
320
 
321
  self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
322
 
323
+ self.attn1 = WindowAttentionGlobal(
324
+ dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
325
+ )
326
+ self.attn2 = WindowAttentionGlobal(
327
+ dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
328
+ )
329
+ self.attn3 = WindowAttentionGlobal(
330
+ dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
331
+ )
332
  self.window1 = window(window_size=window_size[0], dim=dims[0])
333
  self.window2 = window(window_size=window_size[1], dim=dims[1])
334
  self.window3 = window(window_size=window_size[2], dim=dims[2])
335
+ self.conv1 = nn.Conv2d(
336
+ in_channels=dims[0],
337
+ out_channels=dims[0],
338
+ kernel_size=3,
339
+ stride=2,
340
+ padding=1,
341
+ )
342
+ self.conv2 = nn.Conv2d(
343
+ in_channels=dims[1],
344
+ out_channels=dims[1],
345
+ kernel_size=3,
346
+ stride=2,
347
+ padding=1,
348
+ )
349
+ self.conv3 = nn.Conv2d(
350
+ in_channels=dims[2],
351
+ out_channels=dims[2],
352
+ kernel_size=3,
353
+ stride=2,
354
+ padding=1,
355
+ )
356
 
357
  dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
358
+ self.ffn1 = feedforward(
359
+ dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
360
+ )
361
+ self.ffn2 = feedforward(
362
+ dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
363
+ )
364
+ self.ffn3 = feedforward(
365
+ dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
366
+ )
367
+
368
+ self.last_face_conv = nn.Conv2d(
369
+ in_channels=512, out_channels=256, kernel_size=3, padding=1
370
+ )
371
+
372
+ self.embed_q = nn.Sequential(
373
+ nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
374
+ nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1),
375
+ )
376
+ self.embed_k = nn.Sequential(
377
+ nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1)
378
+ )
379
  self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
380
 
381
  def forward(self, x):
382
  x_face = F.interpolate(x, size=112)
383
+ x_face1, x_face2, x_face3 = self.face_landback(x_face)
384
  x_face3 = self.last_face_conv(x_face3)
385
+ x_face1, x_face2, x_face3 = (
386
+ _to_channel_last(x_face1),
387
+ _to_channel_last(x_face2),
388
+ _to_channel_last(x_face3),
389
+ )
390
+
391
+ q1, q2, q3 = (
392
+ _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
393
+ _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
394
+ _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
395
+ )
396
 
397
  x_ir1, x_ir2, x_ir3 = self.ir_back(x)
398
  x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
 
400
  x_window2, shortcut2 = self.window2(x_ir2)
401
  x_window3, shortcut3 = self.window3(x_ir3)
402
 
403
+ o1, o2, o3 = (
404
+ self.attn1(x_window1, q1),
405
+ self.attn2(x_window2, q2),
406
+ self.attn3(x_window3, q3),
407
+ )
408
 
409
+ o1, o2, o3 = (
410
+ self.ffn1(o1, shortcut1),
411
+ self.ffn2(o2, shortcut2),
412
+ self.ffn3(o3, shortcut3),
413
+ )
414
 
415
  o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)
416
 
417
+ o1, o2, o3 = (
418
+ self.embed_q(o1).flatten(2).transpose(1, 2),
419
+ self.embed_k(o2).flatten(2).transpose(1, 2),
420
+ self.embed_v(o3),
421
+ )
422
 
423
  o = torch.cat([o1, o2, o3], dim=1)
424
 
425
  out = self.VIT(o)
426
  return out
427
 
428
+
429
  def compute_param_flop():
430
  model = pyramid_trans_expr2()
431
+ img = torch.rand(size=(1, 3, 224, 224))
432
  flops, params = profile(model, inputs=(img,))
433
+ print(f"flops:{flops/1000**3}G,params:{params/1000**2}M")
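
For context on the window helpers reformatted above: window_partition tiles a channels-last feature map into non-overlapping square windows (and window_reverse undoes it). A short shape round-trip sketch, using illustrative sizes rather than the model's actual feature maps:

import torch

def window_partition(x, window_size, h_w, w_w):
    # (B, H, W, C) -> (B * h_w * w_w, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    return (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )

x = torch.randn(2, 28, 28, 64)                             # two 28x28 maps, 64 channels
windows = window_partition(x, window_size=7, h_w=4, w_w=4)
print(windows.shape)                                       # torch.Size([32, 7, 7, 64])
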
FER/models/vit_model.py CHANGED
@@ -2,6 +2,7 @@
2
  original code from rwightman:
3
  https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
  """
 
5
  from functools import partial
6
  from collections import OrderedDict
7
 
@@ -23,16 +24,24 @@ import torch.hub
23
  from functools import partial
24
  import math
25
 
26
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
- from timm.models.registry import register_model
28
  from timm.models.vision_transformer import _cfg, Mlp, Block
29
  # from .ir50 import Backbone
30
 
31
 
32
  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33
  """3x3 convolution with padding"""
34
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35
- padding=dilation, groups=groups, bias=False, dilation=dilation)
 
 
 
 
 
 
 
 
36
 
37
 
38
  def conv1x1(in_planes, out_planes, stride=1):
@@ -40,7 +49,7 @@ def conv1x1(in_planes, out_planes, stride=1):
40
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41
 
42
 
43
- def drop_path(x, drop_prob: float = 0., training: bool = False):
44
  """
45
  Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
  This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -49,10 +58,12 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
49
  changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
50
  'survival rate' as the argument.
51
  """
52
- if drop_prob == 0. or not training:
53
  return x
54
  keep_prob = 1 - drop_prob
55
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
 
 
56
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57
  random_tensor.floor_() # binarize
58
  output = x.div(keep_prob) * random_tensor
@@ -60,7 +71,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
60
 
61
 
62
  class BasicBlock(nn.Module):
63
- __constants__ = ['downsample']
64
 
65
  def __init__(self, inplanes, planes, stride=1, downsample=None):
66
  super(BasicBlock, self).__init__()
@@ -109,7 +120,9 @@ class PatchEmbed(nn.Module):
109
  2D Image to Patch Embedding
110
  """
111
 
112
- def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
 
 
113
  super().__init__()
114
  img_size = (img_size, img_size)
115
  patch_size = (patch_size, patch_size)
@@ -135,29 +148,36 @@ class PatchEmbed(nn.Module):
135
 
136
 
137
  class Attention(nn.Module):
138
- def __init__(self,
139
- dim, in_chans, # dim of the input tokens
140
- num_heads=8,
141
- qkv_bias=False,
142
- qk_scale=None,
143
- attn_drop_ratio=0.,
144
- proj_drop_ratio=0.):
 
 
 
145
  super(Attention, self).__init__()
146
  self.num_heads = 8
147
  self.img_chanel = in_chans + 1
148
  head_dim = dim // num_heads
149
- self.scale = head_dim ** -0.5
150
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
151
  self.attn_drop = nn.Dropout(attn_drop_ratio)
152
  self.proj = nn.Linear(dim, dim)
153
  self.proj_drop = nn.Dropout(proj_drop_ratio)
154
 
155
  def forward(self, x):
156
- x_img = x[:, :self.img_chanel, :]
157
  # [batch_size, num_patches + 1, total_embed_dim]
158
  B, N, C = x_img.shape
159
  # print(C)
160
- qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
161
  q, k, v = qkv[0], qkv[1], qkv[2]
162
  # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
163
  # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
@@ -193,7 +213,7 @@ class Attention(nn.Module):
193
 
194
 
195
  class AttentionBlock(nn.Module):
196
- __constants__ = ['downsample']
197
 
198
  def __init__(self, inplanes, planes, stride=1, downsample=None):
199
  super(AttentionBlock, self).__init__()
@@ -234,7 +254,14 @@ class Mlp(nn.Module):
234
  MLP as used in Vision Transformer, MLP-Mixer and related networks
235
  """
236
 
237
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
 
 
 
 
 
 
 
238
  super().__init__()
239
  out_features = out_features or in_features
240
  hidden_features = hidden_features or in_features
@@ -253,29 +280,46 @@ class Mlp(nn.Module):
253
 
254
 
255
  class Block(nn.Module):
256
- def __init__(self,
257
- dim, in_chans,
258
- num_heads,
259
- mlp_ratio=4.,
260
- qkv_bias=False,
261
- qk_scale=None,
262
- drop_ratio=0.,
263
- attn_drop_ratio=0.,
264
- drop_path_ratio=0.,
265
- act_layer=nn.GELU,
266
- norm_layer=nn.LayerNorm):
 
 
 
267
  super(Block, self).__init__()
268
  self.norm1 = norm_layer(dim)
269
  self.img_chanel = in_chans + 1
270
 
271
  self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
272
- self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
- attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
 
 
 
 
 
 
 
274
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
- self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
 
 
276
  self.norm2 = norm_layer(dim)
277
  mlp_hidden_dim = int(dim * mlp_ratio)
278
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
 
 
 
 
 
279
 
280
  def forward(self, x):
281
  # x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -308,8 +352,9 @@ class ClassificationHead(nn.Module):
308
 
309
  def load_pretrained_weights(model, checkpoint):
310
  import collections
311
- if 'state_dict' in checkpoint:
312
- state_dict = checkpoint['state_dict']
 
313
  else:
314
  state_dict = checkpoint
315
  model_dict = model.state_dict()
@@ -318,7 +363,7 @@ def load_pretrained_weights(model, checkpoint):
318
  for k, v in state_dict.items():
319
  # If the pretrained state_dict was saved as nn.DataParallel,
320
  # keys would contain "module.", which should be ignored.
321
- if k.startswith('module.'):
322
  k = k[7:]
323
  if k in model_dict and model_dict[k].size() == v.size():
324
  new_state_dict[k] = v
@@ -329,9 +374,10 @@ def load_pretrained_weights(model, checkpoint):
329
  model_dict.update(new_state_dict)
330
 
331
  model.load_state_dict(model_dict)
332
- print('load_weight', len(matched_layers))
333
  return model
334
 
 
335
  class eca_block(nn.Module):
336
  def __init__(self, channel=128, b=1, gamma=2):
337
  super(eca_block, self).__init__()
@@ -339,7 +385,9 @@ class eca_block(nn.Module):
339
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
340
 
341
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
342
- self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
 
 
343
  self.sigmoid = nn.Sigmoid()
344
 
345
  def forward(self, x):
@@ -347,6 +395,8 @@ class eca_block(nn.Module):
347
  y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
348
  y = self.sigmoid(y)
349
  return x * y.expand_as(x)
 
 
350
  #
351
  #
352
  # class IR20(nn.Module):
@@ -484,7 +534,9 @@ class eca_block(nn.Module):
484
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
485
 
486
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
487
- self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
 
 
488
  self.sigmoid = nn.Sigmoid()
489
 
490
  def forward(self, x):
@@ -493,6 +545,7 @@ class eca_block(nn.Module):
493
  y = self.sigmoid(y)
494
  return x * y.expand_as(x)
495
 
 
496
  class SE_block(nn.Module):
497
  def __init__(self, input_dim: int):
498
  super().__init__()
@@ -511,11 +564,27 @@ class SE_block(nn.Module):
511
 
512
 
513
  class VisionTransformer(nn.Module):
514
- def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=7,
515
- embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
516
- qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
517
- attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
518
- act_layer=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  """
520
  Args:
521
  img_size (int, tuple): input image size
@@ -538,7 +607,9 @@ class VisionTransformer(nn.Module):
538
  """
539
  super(VisionTransformer, self).__init__()
540
  self.num_classes = num_classes
541
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
 
 
542
  self.num_tokens = 2 if distilled else 1
543
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
544
  act_layer = act_layer or nn.GELU
@@ -549,18 +620,20 @@ class VisionTransformer(nn.Module):
549
 
550
  self.se_block = SE_block(input_dim=embed_dim)
551
 
552
-
553
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
 
554
  num_patches = self.patch_embed.num_patches
555
  self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
556
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
557
- self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
 
 
558
  # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
559
  self.pos_drop = nn.Dropout(p=drop_ratio)
560
  # self.IR = IR()
561
  self.eca_block = eca_block()
562
 
563
-
564
  # self.ir_back = Backbone(50, 0.0, 'ir')
565
  # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
566
  # # ir_checkpoint = ir_checkpoint["model"]
@@ -570,24 +643,41 @@ class VisionTransformer(nn.Module):
570
  self.IRLinear1 = nn.Linear(1024, 768)
571
  self.IRLinear2 = nn.Linear(768, 512)
572
  self.eca_block = eca_block()
573
- dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
574
- self.blocks = nn.Sequential(*[
575
- Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
576
- qk_scale=qk_scale,
577
- drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
578
- norm_layer=norm_layer, act_layer=act_layer)
579
- for i in range(depth)
580
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  self.norm = norm_layer(embed_dim)
582
 
583
  # Representation layer
584
  if representation_size and not distilled:
585
  self.has_logits = True
586
  self.num_features = representation_size
587
- self.pre_logits = nn.Sequential(OrderedDict([
588
- ("fc", nn.Linear(embed_dim, representation_size)),
589
- ("act", nn.Tanh())
590
- ]))
 
 
 
 
591
  else:
592
  self.has_logits = False
593
  self.pre_logits = nn.Identity()
@@ -596,7 +686,11 @@ class VisionTransformer(nn.Module):
596
  # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
597
  self.head_dist = None
598
  if distilled:
599
- self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
 
 
 
600
 
601
  # Weight init
602
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
@@ -616,7 +710,9 @@ class VisionTransformer(nn.Module):
616
  if self.dist_token is None:
617
  x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
618
  else:
619
- x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
 
 
620
  # print(x.shape)
621
  x = self.pos_drop(x + self.pos_embed)
622
  x = self.blocks(x)
@@ -627,7 +723,6 @@ class VisionTransformer(nn.Module):
627
  return x[:, 0], x[:, 1]
628
 
629
  def forward(self, x):
630
-
631
  # B = x.shape[0]
632
  # print(x)
633
  # x = self.eca_block(x)
@@ -680,7 +775,7 @@ def _init_vit_weights(m):
680
  :param m: module
681
  """
682
  if isinstance(m, nn.Linear):
683
- nn.init.trunc_normal_(m.weight, std=.01)
684
  if m.bias is not None:
685
  nn.init.zeros_(m.bias)
686
  elif isinstance(m, nn.Conv2d):
@@ -699,13 +794,15 @@ def vit_base_patch16_224(num_classes: int = 7):
699
  weights ported from official Google JAX impl:
700
  Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA password: eu9f
701
  """
702
- model = VisionTransformer(img_size=224,
703
- patch_size=16,
704
- embed_dim=768,
705
- depth=12,
706
- num_heads=12,
707
- representation_size=None,
708
- num_classes=num_classes)
 
 
709
 
710
  return model
711
 
@@ -717,13 +814,15 @@ def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True
717
  weights ported from official Google JAX impl:
718
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
719
  """
720
- model = VisionTransformer(img_size=224,
721
- patch_size=16,
722
- embed_dim=768,
723
- depth=12,
724
- num_heads=12,
725
- representation_size=768 if has_logits else None,
726
- num_classes=num_classes)
 
 
727
  return model
728
 
729
 
@@ -734,13 +833,15 @@ def vit_base_patch32_224(num_classes: int = 1000):
734
  weights ported from official Google JAX impl:
735
  Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg password: s5hl
736
  """
737
- model = VisionTransformer(img_size=224,
738
- patch_size=32,
739
- embed_dim=768,
740
- depth=12,
741
- num_heads=12,
742
- representation_size=None,
743
- num_classes=num_classes)
 
 
744
  return model
745
 
746
 
@@ -751,13 +852,15 @@ def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True
751
  weights ported from official Google JAX impl:
752
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
753
  """
754
- model = VisionTransformer(img_size=224,
755
- patch_size=32,
756
- embed_dim=768,
757
- depth=12,
758
- num_heads=12,
759
- representation_size=768 if has_logits else None,
760
- num_classes=num_classes)
 
 
761
  return model
762
 
763
 
@@ -768,13 +871,15 @@ def vit_large_patch16_224(num_classes: int = 1000):
768
  weights ported from official Google JAX impl:
769
  Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ password: qqt8
770
  """
771
- model = VisionTransformer(img_size=224,
772
- patch_size=16,
773
- embed_dim=1024,
774
- depth=24,
775
- num_heads=16,
776
- representation_size=None,
777
- num_classes=num_classes)
 
 
778
  return model
779
 
780
 
@@ -785,13 +890,15 @@ def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = Tru
785
  weights ported from official Google JAX impl:
786
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
787
  """
788
- model = VisionTransformer(img_size=224,
789
- patch_size=16,
790
- embed_dim=1024,
791
- depth=24,
792
- num_heads=16,
793
- representation_size=1024 if has_logits else None,
794
- num_classes=num_classes)
 
 
795
  return model
796
 
797
 
@@ -802,13 +909,15 @@ def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = Tru
802
  weights ported from official Google JAX impl:
803
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
804
  """
805
- model = VisionTransformer(img_size=224,
806
- patch_size=32,
807
- embed_dim=1024,
808
- depth=24,
809
- num_heads=16,
810
- representation_size=1024 if has_logits else None,
811
- num_classes=num_classes)
 
 
812
  return model
813
 
814
 
@@ -818,11 +927,13 @@ def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True
818
  ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
819
  NOTE: converted weights not currently available, too large for github release hosting.
820
  """
821
- model = VisionTransformer(img_size=224,
822
- patch_size=14,
823
- embed_dim=1280,
824
- depth=32,
825
- num_heads=16,
826
- representation_size=1280 if has_logits else None,
827
- num_classes=num_classes)
 
 
828
  return model
 
2
  original code from rwightman:
3
  https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
  """
5
+
6
  from functools import partial
7
  from collections import OrderedDict
8
 
 
24
  from functools import partial
25
  import math
26
 
27
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
28
+ from timm.models import register_model
29
  from timm.models.vision_transformer import _cfg, Mlp, Block
30
  # from .ir50 import Backbone
31
 
32
 
33
  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
34
  """3x3 convolution with padding"""
35
+ return nn.Conv2d(
36
+ in_planes,
37
+ out_planes,
38
+ kernel_size=3,
39
+ stride=stride,
40
+ padding=dilation,
41
+ groups=groups,
42
+ bias=False,
43
+ dilation=dilation,
44
+ )
45
 
46
 
47
  def conv1x1(in_planes, out_planes, stride=1):
 
49
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
50
 
51
 
52
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
53
  """
54
  Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
55
  This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
 
58
  changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
59
  'survival rate' as the argument.
60
  """
61
+ if drop_prob == 0.0 or not training:
62
  return x
63
  keep_prob = 1 - drop_prob
64
+ shape = (x.shape[0],) + (1,) * (
65
+ x.ndim - 1
66
+ ) # work with diff dim tensors, not just 2D ConvNets
67
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
68
  random_tensor.floor_() # binarize
69
  output = x.div(keep_prob) * random_tensor
 
71
 
72
 
73
  class BasicBlock(nn.Module):
74
+ __constants__ = ["downsample"]
75
 
76
  def __init__(self, inplanes, planes, stride=1, downsample=None):
77
  super(BasicBlock, self).__init__()
 
120
  2D Image to Patch Embedding
121
  """
122
 
123
+ def __init__(
124
+ self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
125
+ ):
126
  super().__init__()
127
  img_size = (img_size, img_size)
128
  patch_size = (patch_size, patch_size)
 
148
 
149
 
150
  class Attention(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim,
154
+ in_chans,  # dim of the input tokens
155
+ num_heads=8,
156
+ qkv_bias=False,
157
+ qk_scale=None,
158
+ attn_drop_ratio=0.0,
159
+ proj_drop_ratio=0.0,
160
+ ):
161
  super(Attention, self).__init__()
162
  self.num_heads = 8
163
  self.img_chanel = in_chans + 1
164
  head_dim = dim // num_heads
165
+ self.scale = head_dim**-0.5
166
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
167
  self.attn_drop = nn.Dropout(attn_drop_ratio)
168
  self.proj = nn.Linear(dim, dim)
169
  self.proj_drop = nn.Dropout(proj_drop_ratio)
170
 
171
  def forward(self, x):
172
+ x_img = x[:, : self.img_chanel, :]
173
  # [batch_size, num_patches + 1, total_embed_dim]
174
  B, N, C = x_img.shape
175
  # print(C)
176
+ qkv = (
177
+ self.qkv(x_img)
178
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
179
+ .permute(2, 0, 3, 1, 4)
180
+ )
181
  q, k, v = qkv[0], qkv[1], qkv[2]
182
  # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
183
  # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
213
 
214
 
215
  class AttentionBlock(nn.Module):
216
+ __constants__ = ["downsample"]
217
 
218
  def __init__(self, inplanes, planes, stride=1, downsample=None):
219
  super(AttentionBlock, self).__init__()
 
254
  MLP as used in Vision Transformer, MLP-Mixer and related networks
255
  """
256
 
257
+ def __init__(
258
+ self,
259
+ in_features,
260
+ hidden_features=None,
261
+ out_features=None,
262
+ act_layer=nn.GELU,
263
+ drop=0.0,
264
+ ):
265
  super().__init__()
266
  out_features = out_features or in_features
267
  hidden_features = hidden_features or in_features
 
280
 
281
 
282
  class Block(nn.Module):
283
+ def __init__(
284
+ self,
285
+ dim,
286
+ in_chans,
287
+ num_heads,
288
+ mlp_ratio=4.0,
289
+ qkv_bias=False,
290
+ qk_scale=None,
291
+ drop_ratio=0.0,
292
+ attn_drop_ratio=0.0,
293
+ drop_path_ratio=0.0,
294
+ act_layer=nn.GELU,
295
+ norm_layer=nn.LayerNorm,
296
+ ):
297
  super(Block, self).__init__()
298
  self.norm1 = norm_layer(dim)
299
  self.img_chanel = in_chans + 1
300
 
301
  self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
302
+ self.attn = Attention(
303
+ dim,
304
+ in_chans=in_chans,
305
+ num_heads=num_heads,
306
+ qkv_bias=qkv_bias,
307
+ qk_scale=qk_scale,
308
+ attn_drop_ratio=attn_drop_ratio,
309
+ proj_drop_ratio=drop_ratio,
310
+ )
311
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
312
+ self.drop_path = (
313
+ DropPath(drop_path_ratio) if drop_path_ratio > 0.0 else nn.Identity()
314
+ )
315
  self.norm2 = norm_layer(dim)
316
  mlp_hidden_dim = int(dim * mlp_ratio)
317
+ self.mlp = Mlp(
318
+ in_features=dim,
319
+ hidden_features=mlp_hidden_dim,
320
+ act_layer=act_layer,
321
+ drop=drop_ratio,
322
+ )
323
 
324
  def forward(self, x):
325
  # x = x + self.drop_path(self.attn(self.norm1(x)))
 
352
 
353
  def load_pretrained_weights(model, checkpoint):
354
  import collections
355
+
356
+ if "state_dict" in checkpoint:
357
+ state_dict = checkpoint["state_dict"]
358
  else:
359
  state_dict = checkpoint
360
  model_dict = model.state_dict()
 
363
  for k, v in state_dict.items():
364
  # If the pretrained state_dict was saved as nn.DataParallel,
365
  # keys would contain "module.", which should be ignored.
366
+ if k.startswith("module."):
367
  k = k[7:]
368
  if k in model_dict and model_dict[k].size() == v.size():
369
  new_state_dict[k] = v
 
374
  model_dict.update(new_state_dict)
375
 
376
  model.load_state_dict(model_dict)
377
+ print("load_weight", len(matched_layers))
378
  return model
379
 
380
+
381
  class eca_block(nn.Module):
382
  def __init__(self, channel=128, b=1, gamma=2):
383
  super(eca_block, self).__init__()
 
385
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
386
 
387
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
388
+ self.conv = nn.Conv1d(
389
+ 1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
390
+ )
391
  self.sigmoid = nn.Sigmoid()
392
 
393
  def forward(self, x):
 
395
  y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
396
  y = self.sigmoid(y)
397
  return x * y.expand_as(x)
398
+
399
+
400
  #
401
  #
402
  # class IR20(nn.Module):
 
534
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
535
 
536
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
537
+ self.conv = nn.Conv1d(
538
+ 1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
539
+ )
540
  self.sigmoid = nn.Sigmoid()
541
 
542
  def forward(self, x):
 
545
  y = self.sigmoid(y)
546
  return x * y.expand_as(x)
547
 
548
+
549
  class SE_block(nn.Module):
550
  def __init__(self, input_dim: int):
551
  super().__init__()
 
564
 
565
 
566
  class VisionTransformer(nn.Module):
567
+ def __init__(
568
+ self,
569
+ img_size=14,
570
+ patch_size=14,
571
+ in_c=147,
572
+ num_classes=7,
573
+ embed_dim=768,
574
+ depth=6,
575
+ num_heads=8,
576
+ mlp_ratio=4.0,
577
+ qkv_bias=True,
578
+ qk_scale=None,
579
+ representation_size=None,
580
+ distilled=False,
581
+ drop_ratio=0.0,
582
+ attn_drop_ratio=0.0,
583
+ drop_path_ratio=0.0,
584
+ embed_layer=PatchEmbed,
585
+ norm_layer=None,
586
+ act_layer=None,
587
+ ):
588
  """
589
  Args:
590
  img_size (int, tuple): input image size
 
607
  """
608
  super(VisionTransformer, self).__init__()
609
  self.num_classes = num_classes
610
+ self.num_features = self.embed_dim = (
611
+ embed_dim # num_features for consistency with other models
612
+ )
613
  self.num_tokens = 2 if distilled else 1
614
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
615
  act_layer = act_layer or nn.GELU
 
620
 
621
  self.se_block = SE_block(input_dim=embed_dim)
622
 
623
+ self.patch_embed = embed_layer(
624
+ img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
625
+ )
626
  num_patches = self.patch_embed.num_patches
627
  self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
628
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
629
+ self.dist_token = (
630
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
631
+ )
632
  # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
633
  self.pos_drop = nn.Dropout(p=drop_ratio)
634
  # self.IR = IR()
635
  self.eca_block = eca_block()
636
 
 
637
  # self.ir_back = Backbone(50, 0.0, 'ir')
638
  # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
639
  # # ir_checkpoint = ir_checkpoint["model"]
 
643
  self.IRLinear1 = nn.Linear(1024, 768)
644
  self.IRLinear2 = nn.Linear(768, 512)
645
  self.eca_block = eca_block()
646
+ dpr = [
647
+ x.item() for x in torch.linspace(0, drop_path_ratio, depth)
648
+ ] # stochastic depth decay rule
649
+ self.blocks = nn.Sequential(
650
+ *[
651
+ Block(
652
+ dim=embed_dim,
653
+ in_chans=in_c,
654
+ num_heads=num_heads,
655
+ mlp_ratio=mlp_ratio,
656
+ qkv_bias=qkv_bias,
657
+ qk_scale=qk_scale,
658
+ drop_ratio=drop_ratio,
659
+ attn_drop_ratio=attn_drop_ratio,
660
+ drop_path_ratio=dpr[i],
661
+ norm_layer=norm_layer,
662
+ act_layer=act_layer,
663
+ )
664
+ for i in range(depth)
665
+ ]
666
+ )
667
  self.norm = norm_layer(embed_dim)
668
 
669
  # Representation layer
670
  if representation_size and not distilled:
671
  self.has_logits = True
672
  self.num_features = representation_size
673
+ self.pre_logits = nn.Sequential(
674
+ OrderedDict(
675
+ [
676
+ ("fc", nn.Linear(embed_dim, representation_size)),
677
+ ("act", nn.Tanh()),
678
+ ]
679
+ )
680
+ )
681
  else:
682
  self.has_logits = False
683
  self.pre_logits = nn.Identity()
 
686
  # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
687
  self.head_dist = None
688
  if distilled:
689
+ self.head_dist = (
690
+ nn.Linear(self.embed_dim, self.num_classes)
691
+ if num_classes > 0
692
+ else nn.Identity()
693
+ )
694
 
695
  # Weight init
696
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
710
  if self.dist_token is None:
711
  x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
712
  else:
713
+ x = torch.cat(
714
+ (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1
715
+ )
716
  # print(x.shape)
717
  x = self.pos_drop(x + self.pos_embed)
718
  x = self.blocks(x)
 
723
  return x[:, 0], x[:, 1]
724
 
725
  def forward(self, x):
 
726
  # B = x.shape[0]
727
  # print(x)
728
  # x = self.eca_block(x)
 
775
  :param m: module
776
  """
777
  if isinstance(m, nn.Linear):
778
+ nn.init.trunc_normal_(m.weight, std=0.01)
779
  if m.bias is not None:
780
  nn.init.zeros_(m.bias)
781
  elif isinstance(m, nn.Conv2d):
 
794
  weights ported from official Google JAX impl:
795
  Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA password: eu9f
796
  """
797
+ model = VisionTransformer(
798
+ img_size=224,
799
+ patch_size=16,
800
+ embed_dim=768,
801
+ depth=12,
802
+ num_heads=12,
803
+ representation_size=None,
804
+ num_classes=num_classes,
805
+ )
806
 
807
  return model
808
 
 
814
  weights ported from official Google JAX impl:
815
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
816
  """
817
+ model = VisionTransformer(
818
+ img_size=224,
819
+ patch_size=16,
820
+ embed_dim=768,
821
+ depth=12,
822
+ num_heads=12,
823
+ representation_size=768 if has_logits else None,
824
+ num_classes=num_classes,
825
+ )
826
  return model
827
 
828
 
 
833
  weights ported from official Google JAX impl:
834
  Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg Password: s5hl
835
  """
836
+ model = VisionTransformer(
837
+ img_size=224,
838
+ patch_size=32,
839
+ embed_dim=768,
840
+ depth=12,
841
+ num_heads=12,
842
+ representation_size=None,
843
+ num_classes=num_classes,
844
+ )
845
  return model
846
 
847
 
 
852
  weights ported from official Google JAX impl:
853
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
854
  """
855
+ model = VisionTransformer(
856
+ img_size=224,
857
+ patch_size=32,
858
+ embed_dim=768,
859
+ depth=12,
860
+ num_heads=12,
861
+ representation_size=768 if has_logits else None,
862
+ num_classes=num_classes,
863
+ )
864
  return model
865
 
866
 
 
871
  weights ported from official Google JAX impl:
872
  Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ Password: qqt8
873
  """
874
+ model = VisionTransformer(
875
+ img_size=224,
876
+ patch_size=16,
877
+ embed_dim=1024,
878
+ depth=24,
879
+ num_heads=16,
880
+ representation_size=None,
881
+ num_classes=num_classes,
882
+ )
883
  return model
884
 
885
 
 
890
  weights ported from official Google JAX impl:
891
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
892
  """
893
+ model = VisionTransformer(
894
+ img_size=224,
895
+ patch_size=16,
896
+ embed_dim=1024,
897
+ depth=24,
898
+ num_heads=16,
899
+ representation_size=1024 if has_logits else None,
900
+ num_classes=num_classes,
901
+ )
902
  return model
903
 
904
 
 
909
  weights ported from official Google JAX impl:
910
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
911
  """
912
+ model = VisionTransformer(
913
+ img_size=224,
914
+ patch_size=32,
915
+ embed_dim=1024,
916
+ depth=24,
917
+ num_heads=16,
918
+ representation_size=1024 if has_logits else None,
919
+ num_classes=num_classes,
920
+ )
921
  return model
922
 
923
 
 
927
  ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
928
  NOTE: converted weights not currently available, too large for github release hosting.
929
  """
930
+ model = VisionTransformer(
931
+ img_size=224,
932
+ patch_size=14,
933
+ embed_dim=1280,
934
+ depth=32,
935
+ num_heads=16,
936
+ representation_size=1280 if has_logits else None,
937
+ num_classes=num_classes,
938
+ )
939
  return model
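In the updated constructor above, drop_path_ratio is spread linearly across the transformer blocks via torch.linspace (the dpr list). A short self-contained sketch of that decay rule, with example values rather than the model's actual configuration:

import torch

# Stochastic depth decay rule: block i receives i / (depth - 1) * drop_path_ratio.
drop_path_ratio, depth = 0.1, 6  # example values only
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
print([round(r, 2) for r in dpr])  # [0.0, 0.02, 0.04, 0.06, 0.08, 0.1]

Earlier blocks are therefore almost never dropped, while the deepest block is dropped at the full drop_path_ratio during training.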
FER/models/vit_model_8.py CHANGED
@@ -2,6 +2,7 @@
2
  original code from rwightman:
3
  https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
  """
 
5
  from functools import partial
6
  from collections import OrderedDict
7
 
@@ -23,16 +24,24 @@ import torch.hub
23
  from functools import partial
24
  import math
25
 
26
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
- from timm.models.registry import register_model
28
  from timm.models.vision_transformer import _cfg, Mlp, Block
29
  from .ir50 import Backbone
30
 
31
 
32
  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33
  """3x3 convolution with padding"""
34
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35
- padding=dilation, groups=groups, bias=False, dilation=dilation)
 
 
 
 
 
 
 
 
36
 
37
 
38
  def conv1x1(in_planes, out_planes, stride=1):
@@ -40,7 +49,7 @@ def conv1x1(in_planes, out_planes, stride=1):
40
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41
 
42
 
43
- def drop_path(x, drop_prob: float = 0., training: bool = False):
44
  """
45
  Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
  This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -49,10 +58,12 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
49
  changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
50
  'survival rate' as the argument.
51
  """
52
- if drop_prob == 0. or not training:
53
  return x
54
  keep_prob = 1 - drop_prob
55
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
 
 
56
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57
  random_tensor.floor_() # binarize
58
  output = x.div(keep_prob) * random_tensor
@@ -60,7 +71,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
60
 
61
 
62
  class BasicBlock(nn.Module):
63
- __constants__ = ['downsample']
64
 
65
  def __init__(self, inplanes, planes, stride=1, downsample=None):
66
  super(BasicBlock, self).__init__()
@@ -109,7 +120,9 @@ class PatchEmbed(nn.Module):
109
  2D Image to Patch Embedding
110
  """
111
 
112
- def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
 
 
113
  super().__init__()
114
  img_size = (img_size, img_size)
115
  patch_size = (patch_size, patch_size)
@@ -135,29 +148,36 @@ class PatchEmbed(nn.Module):
135
 
136
 
137
  class Attention(nn.Module):
138
- def __init__(self,
139
- dim, in_chans, # input token dim
140
- num_heads=8,
141
- qkv_bias=False,
142
- qk_scale=None,
143
- attn_drop_ratio=0.,
144
- proj_drop_ratio=0.):
 
 
 
145
  super(Attention, self).__init__()
146
  self.num_heads = 8
147
  self.img_chanel = in_chans + 1
148
  head_dim = dim // num_heads
149
- self.scale = head_dim ** -0.5
150
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
151
  self.attn_drop = nn.Dropout(attn_drop_ratio)
152
  self.proj = nn.Linear(dim, dim)
153
  self.proj_drop = nn.Dropout(proj_drop_ratio)
154
 
155
  def forward(self, x):
156
- x_img = x[:, :self.img_chanel, :]
157
  # [batch_size, num_patches + 1, total_embed_dim]
158
  B, N, C = x_img.shape
159
  # print(C)
160
- qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
161
  q, k, v = qkv[0], qkv[1], qkv[2]
162
  # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
163
  # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
@@ -193,7 +213,7 @@ class Attention(nn.Module):
193
 
194
 
195
  class AttentionBlock(nn.Module):
196
- __constants__ = ['downsample']
197
 
198
  def __init__(self, inplanes, planes, stride=1, downsample=None):
199
  super(AttentionBlock, self).__init__()
@@ -234,7 +254,14 @@ class Mlp(nn.Module):
234
  MLP as used in Vision Transformer, MLP-Mixer and related networks
235
  """
236
 
237
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
 
 
 
 
 
 
 
238
  super().__init__()
239
  out_features = out_features or in_features
240
  hidden_features = hidden_features or in_features
@@ -253,29 +280,46 @@ class Mlp(nn.Module):
253
 
254
 
255
  class Block(nn.Module):
256
- def __init__(self,
257
- dim, in_chans,
258
- num_heads,
259
- mlp_ratio=4.,
260
- qkv_bias=False,
261
- qk_scale=None,
262
- drop_ratio=0.,
263
- attn_drop_ratio=0.,
264
- drop_path_ratio=0.,
265
- act_layer=nn.GELU,
266
- norm_layer=nn.LayerNorm):
 
 
 
267
  super(Block, self).__init__()
268
  self.norm1 = norm_layer(dim)
269
  self.img_chanel = in_chans + 1
270
 
271
  self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
272
- self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
- attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
 
 
 
 
 
 
 
274
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
- self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
 
 
276
  self.norm2 = norm_layer(dim)
277
  mlp_hidden_dim = int(dim * mlp_ratio)
278
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
 
 
 
 
 
279
 
280
  def forward(self, x):
281
  # x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -308,8 +352,9 @@ class ClassificationHead(nn.Module):
308
 
309
  def load_pretrained_weights(model, checkpoint):
310
  import collections
311
- if 'state_dict' in checkpoint:
312
- state_dict = checkpoint['state_dict']
 
313
  else:
314
  state_dict = checkpoint
315
  model_dict = model.state_dict()
@@ -318,7 +363,7 @@ def load_pretrained_weights(model, checkpoint):
318
  for k, v in state_dict.items():
319
  # If the pretrained state_dict was saved as nn.DataParallel,
320
  # keys would contain "module.", which should be ignored.
321
- if k.startswith('module.'):
322
  k = k[7:]
323
  if k in model_dict and model_dict[k].size() == v.size():
324
  new_state_dict[k] = v
@@ -329,9 +374,10 @@ def load_pretrained_weights(model, checkpoint):
329
  model_dict.update(new_state_dict)
330
 
331
  model.load_state_dict(model_dict)
332
- print('load_weight', len(matched_layers))
333
  return model
334
 
 
335
  class eca_block(nn.Module):
336
  def __init__(self, channel=128, b=1, gamma=2):
337
  super(eca_block, self).__init__()
@@ -339,7 +385,9 @@ class eca_block(nn.Module):
339
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
340
 
341
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
342
- self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
 
 
343
  self.sigmoid = nn.Sigmoid()
344
 
345
  def forward(self, x):
@@ -347,6 +395,8 @@ class eca_block(nn.Module):
347
  y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
348
  y = self.sigmoid(y)
349
  return x * y.expand_as(x)
 
 
350
  #
351
  #
352
  # class IR20(nn.Module):
@@ -484,7 +534,9 @@ class eca_block(nn.Module):
484
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
485
 
486
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
487
- self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
 
 
488
  self.sigmoid = nn.Sigmoid()
489
 
490
  def forward(self, x):
@@ -493,6 +545,7 @@ class eca_block(nn.Module):
493
  y = self.sigmoid(y)
494
  return x * y.expand_as(x)
495
 
 
496
  class SE_block(nn.Module):
497
  def __init__(self, input_dim: int):
498
  super().__init__()
@@ -511,11 +564,27 @@ class SE_block(nn.Module):
511
 
512
 
513
  class VisionTransformer(nn.Module):
514
- def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=8,
515
- embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
516
- qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
517
- attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
518
- act_layer=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  """
520
  Args:
521
  img_size (int, tuple): input image size
@@ -538,7 +607,9 @@ class VisionTransformer(nn.Module):
538
  """
539
  super(VisionTransformer, self).__init__()
540
  self.num_classes = num_classes
541
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
 
 
542
  self.num_tokens = 2 if distilled else 1
543
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
544
  act_layer = act_layer or nn.GELU
@@ -549,18 +620,20 @@ class VisionTransformer(nn.Module):
549
 
550
  self.se_block = SE_block(input_dim=embed_dim)
551
 
552
-
553
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
 
554
  num_patches = self.patch_embed.num_patches
555
  self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
556
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
557
- self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
 
 
558
  # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
559
  self.pos_drop = nn.Dropout(p=drop_ratio)
560
  # self.IR = IR()
561
  self.eca_block = eca_block()
562
 
563
-
564
  # self.ir_back = Backbone(50, 0.0, 'ir')
565
  # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
566
  # # ir_checkpoint = ir_checkpoint["model"]
@@ -570,24 +643,41 @@ class VisionTransformer(nn.Module):
570
  self.IRLinear1 = nn.Linear(1024, 768)
571
  self.IRLinear2 = nn.Linear(768, 512)
572
  self.eca_block = eca_block()
573
- dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
574
- self.blocks = nn.Sequential(*[
575
- Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
576
- qk_scale=qk_scale,
577
- drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
578
- norm_layer=norm_layer, act_layer=act_layer)
579
- for i in range(depth)
580
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  self.norm = norm_layer(embed_dim)
582
 
583
  # Representation layer
584
  if representation_size and not distilled:
585
  self.has_logits = True
586
  self.num_features = representation_size
587
- self.pre_logits = nn.Sequential(OrderedDict([
588
- ("fc", nn.Linear(embed_dim, representation_size)),
589
- ("act", nn.Tanh())
590
- ]))
 
 
 
 
591
  else:
592
  self.has_logits = False
593
  self.pre_logits = nn.Identity()
@@ -596,7 +686,11 @@ class VisionTransformer(nn.Module):
596
  # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
597
  self.head_dist = None
598
  if distilled:
599
- self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
 
 
 
600
 
601
  # Weight init
602
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
@@ -616,7 +710,9 @@ class VisionTransformer(nn.Module):
616
  if self.dist_token is None:
617
  x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
618
  else:
619
- x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
 
 
620
  # print(x.shape)
621
  x = self.pos_drop(x + self.pos_embed)
622
  x = self.blocks(x)
@@ -627,7 +723,6 @@ class VisionTransformer(nn.Module):
627
  return x[:, 0], x[:, 1]
628
 
629
  def forward(self, x):
630
-
631
  # B = x.shape[0]
632
  # print(x)
633
  # x = self.eca_block(x)
@@ -680,7 +775,7 @@ def _init_vit_weights(m):
680
  :param m: module
681
  """
682
  if isinstance(m, nn.Linear):
683
- nn.init.trunc_normal_(m.weight, std=.01)
684
  if m.bias is not None:
685
  nn.init.zeros_(m.bias)
686
  elif isinstance(m, nn.Conv2d):
@@ -699,13 +794,15 @@ def vit_base_patch16_224(num_classes: int = 7):
699
  weights ported from official Google JAX impl:
700
  Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA Password: eu9f
701
  """
702
- model = VisionTransformer(img_size=224,
703
- patch_size=16,
704
- embed_dim=768,
705
- depth=12,
706
- num_heads=12,
707
- representation_size=None,
708
- num_classes=num_classes)
 
 
709
 
710
  return model
711
 
@@ -717,13 +814,15 @@ def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True
717
  weights ported from official Google JAX impl:
718
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
719
  """
720
- model = VisionTransformer(img_size=224,
721
- patch_size=16,
722
- embed_dim=768,
723
- depth=12,
724
- num_heads=12,
725
- representation_size=768 if has_logits else None,
726
- num_classes=num_classes)
 
 
727
  return model
728
 
729
 
@@ -734,13 +833,15 @@ def vit_base_patch32_224(num_classes: int = 1000):
734
  weights ported from official Google JAX impl:
735
  Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg Password: s5hl
736
  """
737
- model = VisionTransformer(img_size=224,
738
- patch_size=32,
739
- embed_dim=768,
740
- depth=12,
741
- num_heads=12,
742
- representation_size=None,
743
- num_classes=num_classes)
 
 
744
  return model
745
 
746
 
@@ -751,13 +852,15 @@ def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True
751
  weights ported from official Google JAX impl:
752
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
753
  """
754
- model = VisionTransformer(img_size=224,
755
- patch_size=32,
756
- embed_dim=768,
757
- depth=12,
758
- num_heads=12,
759
- representation_size=768 if has_logits else None,
760
- num_classes=num_classes)
 
 
761
  return model
762
 
763
 
@@ -768,13 +871,15 @@ def vit_large_patch16_224(num_classes: int = 1000):
768
  weights ported from official Google JAX impl:
769
  Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ Password: qqt8
770
  """
771
- model = VisionTransformer(img_size=224,
772
- patch_size=16,
773
- embed_dim=1024,
774
- depth=24,
775
- num_heads=16,
776
- representation_size=None,
777
- num_classes=num_classes)
 
 
778
  return model
779
 
780
 
@@ -785,13 +890,15 @@ def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = Tru
785
  weights ported from official Google JAX impl:
786
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
787
  """
788
- model = VisionTransformer(img_size=224,
789
- patch_size=16,
790
- embed_dim=1024,
791
- depth=24,
792
- num_heads=16,
793
- representation_size=1024 if has_logits else None,
794
- num_classes=num_classes)
 
 
795
  return model
796
 
797
 
@@ -802,13 +909,15 @@ def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = Tru
802
  weights ported from official Google JAX impl:
803
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
804
  """
805
- model = VisionTransformer(img_size=224,
806
- patch_size=32,
807
- embed_dim=1024,
808
- depth=24,
809
- num_heads=16,
810
- representation_size=1024 if has_logits else None,
811
- num_classes=num_classes)
 
 
812
  return model
813
 
814
 
@@ -818,11 +927,13 @@ def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True
818
  ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
819
  NOTE: converted weights not currently available, too large for github release hosting.
820
  """
821
- model = VisionTransformer(img_size=224,
822
- patch_size=14,
823
- embed_dim=1280,
824
- depth=32,
825
- num_heads=16,
826
- representation_size=1280 if has_logits else None,
827
- num_classes=num_classes)
 
 
828
  return model
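The drop_path hunk above implements stochastic depth: during training, each sample's residual branch is zeroed with probability drop_prob and the surviving samples are rescaled by 1/keep_prob so the expected activation is unchanged. A minimal self-contained illustration of that masking and rescaling; the tensor shape and probability below are example values, not taken from the model:

import torch

torch.manual_seed(0)
x = torch.ones(4, 3)                    # one row per sample in the batch
keep_prob = 0.75                        # i.e. drop_prob = 0.25
mask = (keep_prob + torch.rand(4, 1, dtype=x.dtype)).floor()  # 1 = keep, 0 = drop, per sample
out = x.div(keep_prob) * mask
print(out)  # kept rows become 1 / 0.75 ≈ 1.33, dropped rows are all zeros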
 
2
  original code from rwightman:
3
  https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
  """
5
+
6
  from functools import partial
7
  from collections import OrderedDict
8
 
 
24
  from functools import partial
25
  import math
26
 
27
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
28
+ from timm.models import register_model
29
  from timm.models.vision_transformer import _cfg, Mlp, Block
30
  from .ir50 import Backbone
31
 
32
 
33
  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
34
  """3x3 convolution with padding"""
35
+ return nn.Conv2d(
36
+ in_planes,
37
+ out_planes,
38
+ kernel_size=3,
39
+ stride=stride,
40
+ padding=dilation,
41
+ groups=groups,
42
+ bias=False,
43
+ dilation=dilation,
44
+ )
45
 
46
 
47
  def conv1x1(in_planes, out_planes, stride=1):
 
49
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
50
 
51
 
52
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
53
  """
54
  Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
55
  This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
 
58
  changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
59
  'survival rate' as the argument.
60
  """
61
+ if drop_prob == 0.0 or not training:
62
  return x
63
  keep_prob = 1 - drop_prob
64
+ shape = (x.shape[0],) + (1,) * (
65
+ x.ndim - 1
66
+ ) # work with diff dim tensors, not just 2D ConvNets
67
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
68
  random_tensor.floor_() # binarize
69
  output = x.div(keep_prob) * random_tensor
 
71
 
72
 
73
  class BasicBlock(nn.Module):
74
+ __constants__ = ["downsample"]
75
 
76
  def __init__(self, inplanes, planes, stride=1, downsample=None):
77
  super(BasicBlock, self).__init__()
 
120
  2D Image to Patch Embedding
121
  """
122
 
123
+ def __init__(
124
+ self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None
125
+ ):
126
  super().__init__()
127
  img_size = (img_size, img_size)
128
  patch_size = (patch_size, patch_size)
 
148
 
149
 
150
  class Attention(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim,
154
+ in_chans, # input token dim
155
+ num_heads=8,
156
+ qkv_bias=False,
157
+ qk_scale=None,
158
+ attn_drop_ratio=0.0,
159
+ proj_drop_ratio=0.0,
160
+ ):
161
  super(Attention, self).__init__()
162
  self.num_heads = 8
163
  self.img_chanel = in_chans + 1
164
  head_dim = dim // num_heads
165
+ self.scale = head_dim**-0.5
166
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
167
  self.attn_drop = nn.Dropout(attn_drop_ratio)
168
  self.proj = nn.Linear(dim, dim)
169
  self.proj_drop = nn.Dropout(proj_drop_ratio)
170
 
171
  def forward(self, x):
172
+ x_img = x[:, : self.img_chanel, :]
173
  # [batch_size, num_patches + 1, total_embed_dim]
174
  B, N, C = x_img.shape
175
  # print(C)
176
+ qkv = (
177
+ self.qkv(x_img)
178
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
179
+ .permute(2, 0, 3, 1, 4)
180
+ )
181
  q, k, v = qkv[0], qkv[1], qkv[2]
182
  # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
183
  # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
213
 
214
 
215
  class AttentionBlock(nn.Module):
216
+ __constants__ = ["downsample"]
217
 
218
  def __init__(self, inplanes, planes, stride=1, downsample=None):
219
  super(AttentionBlock, self).__init__()
 
254
  MLP as used in Vision Transformer, MLP-Mixer and related networks
255
  """
256
 
257
+ def __init__(
258
+ self,
259
+ in_features,
260
+ hidden_features=None,
261
+ out_features=None,
262
+ act_layer=nn.GELU,
263
+ drop=0.0,
264
+ ):
265
  super().__init__()
266
  out_features = out_features or in_features
267
  hidden_features = hidden_features or in_features
 
280
 
281
 
282
  class Block(nn.Module):
283
+ def __init__(
284
+ self,
285
+ dim,
286
+ in_chans,
287
+ num_heads,
288
+ mlp_ratio=4.0,
289
+ qkv_bias=False,
290
+ qk_scale=None,
291
+ drop_ratio=0.0,
292
+ attn_drop_ratio=0.0,
293
+ drop_path_ratio=0.0,
294
+ act_layer=nn.GELU,
295
+ norm_layer=nn.LayerNorm,
296
+ ):
297
  super(Block, self).__init__()
298
  self.norm1 = norm_layer(dim)
299
  self.img_chanel = in_chans + 1
300
 
301
  self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
302
+ self.attn = Attention(
303
+ dim,
304
+ in_chans=in_chans,
305
+ num_heads=num_heads,
306
+ qkv_bias=qkv_bias,
307
+ qk_scale=qk_scale,
308
+ attn_drop_ratio=attn_drop_ratio,
309
+ proj_drop_ratio=drop_ratio,
310
+ )
311
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
312
+ self.drop_path = (
313
+ DropPath(drop_path_ratio) if drop_path_ratio > 0.0 else nn.Identity()
314
+ )
315
  self.norm2 = norm_layer(dim)
316
  mlp_hidden_dim = int(dim * mlp_ratio)
317
+ self.mlp = Mlp(
318
+ in_features=dim,
319
+ hidden_features=mlp_hidden_dim,
320
+ act_layer=act_layer,
321
+ drop=drop_ratio,
322
+ )
323
 
324
  def forward(self, x):
325
  # x = x + self.drop_path(self.attn(self.norm1(x)))
 
352
 
353
  def load_pretrained_weights(model, checkpoint):
354
  import collections
355
+
356
+ if "state_dict" in checkpoint:
357
+ state_dict = checkpoint["state_dict"]
358
  else:
359
  state_dict = checkpoint
360
  model_dict = model.state_dict()
 
363
  for k, v in state_dict.items():
364
  # If the pretrained state_dict was saved as nn.DataParallel,
365
  # keys would contain "module.", which should be ignored.
366
+ if k.startswith("module."):
367
  k = k[7:]
368
  if k in model_dict and model_dict[k].size() == v.size():
369
  new_state_dict[k] = v
 
374
  model_dict.update(new_state_dict)
375
 
376
  model.load_state_dict(model_dict)
377
+ print("load_weight", len(matched_layers))
378
  return model
379
 
380
+
381
  class eca_block(nn.Module):
382
  def __init__(self, channel=128, b=1, gamma=2):
383
  super(eca_block, self).__init__()
 
385
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
386
 
387
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
388
+ self.conv = nn.Conv1d(
389
+ 1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
390
+ )
391
  self.sigmoid = nn.Sigmoid()
392
 
393
  def forward(self, x):
 
395
  y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
396
  y = self.sigmoid(y)
397
  return x * y.expand_as(x)
398
+
399
+
400
  #
401
  #
402
  # class IR20(nn.Module):
 
534
  kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
535
 
536
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
537
+ self.conv = nn.Conv1d(
538
+ 1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False
539
+ )
540
  self.sigmoid = nn.Sigmoid()
541
 
542
  def forward(self, x):
 
545
  y = self.sigmoid(y)
546
  return x * y.expand_as(x)
547
 
548
+
549
  class SE_block(nn.Module):
550
  def __init__(self, input_dim: int):
551
  super().__init__()
 
564
 
565
 
566
  class VisionTransformer(nn.Module):
567
+ def __init__(
568
+ self,
569
+ img_size=14,
570
+ patch_size=14,
571
+ in_c=147,
572
+ num_classes=8,
573
+ embed_dim=768,
574
+ depth=6,
575
+ num_heads=8,
576
+ mlp_ratio=4.0,
577
+ qkv_bias=True,
578
+ qk_scale=None,
579
+ representation_size=None,
580
+ distilled=False,
581
+ drop_ratio=0.0,
582
+ attn_drop_ratio=0.0,
583
+ drop_path_ratio=0.0,
584
+ embed_layer=PatchEmbed,
585
+ norm_layer=None,
586
+ act_layer=None,
587
+ ):
588
  """
589
  Args:
590
  img_size (int, tuple): input image size
 
607
  """
608
  super(VisionTransformer, self).__init__()
609
  self.num_classes = num_classes
610
+ self.num_features = self.embed_dim = (
611
+ embed_dim # num_features for consistency with other models
612
+ )
613
  self.num_tokens = 2 if distilled else 1
614
  norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
615
  act_layer = act_layer or nn.GELU
 
620
 
621
  self.se_block = SE_block(input_dim=embed_dim)
622
 
623
+ self.patch_embed = embed_layer(
624
+ img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768
625
+ )
626
  num_patches = self.patch_embed.num_patches
627
  self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
628
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
629
+ self.dist_token = (
630
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
631
+ )
632
  # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
633
  self.pos_drop = nn.Dropout(p=drop_ratio)
634
  # self.IR = IR()
635
  self.eca_block = eca_block()
636
 
 
637
  # self.ir_back = Backbone(50, 0.0, 'ir')
638
  # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
639
  # # ir_checkpoint = ir_checkpoint["model"]
 
643
  self.IRLinear1 = nn.Linear(1024, 768)
644
  self.IRLinear2 = nn.Linear(768, 512)
645
  self.eca_block = eca_block()
646
+ dpr = [
647
+ x.item() for x in torch.linspace(0, drop_path_ratio, depth)
648
+ ] # stochastic depth decay rule
649
+ self.blocks = nn.Sequential(
650
+ *[
651
+ Block(
652
+ dim=embed_dim,
653
+ in_chans=in_c,
654
+ num_heads=num_heads,
655
+ mlp_ratio=mlp_ratio,
656
+ qkv_bias=qkv_bias,
657
+ qk_scale=qk_scale,
658
+ drop_ratio=drop_ratio,
659
+ attn_drop_ratio=attn_drop_ratio,
660
+ drop_path_ratio=dpr[i],
661
+ norm_layer=norm_layer,
662
+ act_layer=act_layer,
663
+ )
664
+ for i in range(depth)
665
+ ]
666
+ )
667
  self.norm = norm_layer(embed_dim)
668
 
669
  # Representation layer
670
  if representation_size and not distilled:
671
  self.has_logits = True
672
  self.num_features = representation_size
673
+ self.pre_logits = nn.Sequential(
674
+ OrderedDict(
675
+ [
676
+ ("fc", nn.Linear(embed_dim, representation_size)),
677
+ ("act", nn.Tanh()),
678
+ ]
679
+ )
680
+ )
681
  else:
682
  self.has_logits = False
683
  self.pre_logits = nn.Identity()
 
686
  # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
687
  self.head_dist = None
688
  if distilled:
689
+ self.head_dist = (
690
+ nn.Linear(self.embed_dim, self.num_classes)
691
+ if num_classes > 0
692
+ else nn.Identity()
693
+ )
694
 
695
  # Weight init
696
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
 
710
  if self.dist_token is None:
711
  x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
712
  else:
713
+ x = torch.cat(
714
+ (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1
715
+ )
716
  # print(x.shape)
717
  x = self.pos_drop(x + self.pos_embed)
718
  x = self.blocks(x)
 
723
  return x[:, 0], x[:, 1]
724
 
725
  def forward(self, x):
 
726
  # B = x.shape[0]
727
  # print(x)
728
  # x = self.eca_block(x)
 
775
  :param m: module
776
  """
777
  if isinstance(m, nn.Linear):
778
+ nn.init.trunc_normal_(m.weight, std=0.01)
779
  if m.bias is not None:
780
  nn.init.zeros_(m.bias)
781
  elif isinstance(m, nn.Conv2d):
 
794
  weights ported from official Google JAX impl:
795
  Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA Password: eu9f
796
  """
797
+ model = VisionTransformer(
798
+ img_size=224,
799
+ patch_size=16,
800
+ embed_dim=768,
801
+ depth=12,
802
+ num_heads=12,
803
+ representation_size=None,
804
+ num_classes=num_classes,
805
+ )
806
 
807
  return model
808
 
 
814
  weights ported from official Google JAX impl:
815
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
816
  """
817
+ model = VisionTransformer(
818
+ img_size=224,
819
+ patch_size=16,
820
+ embed_dim=768,
821
+ depth=12,
822
+ num_heads=12,
823
+ representation_size=768 if has_logits else None,
824
+ num_classes=num_classes,
825
+ )
826
  return model
827
 
828
 
 
833
  weights ported from official Google JAX impl:
834
  Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg Password: s5hl
835
  """
836
+ model = VisionTransformer(
837
+ img_size=224,
838
+ patch_size=32,
839
+ embed_dim=768,
840
+ depth=12,
841
+ num_heads=12,
842
+ representation_size=None,
843
+ num_classes=num_classes,
844
+ )
845
  return model
846
 
847
 
 
852
  weights ported from official Google JAX impl:
853
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
854
  """
855
+ model = VisionTransformer(
856
+ img_size=224,
857
+ patch_size=32,
858
+ embed_dim=768,
859
+ depth=12,
860
+ num_heads=12,
861
+ representation_size=768 if has_logits else None,
862
+ num_classes=num_classes,
863
+ )
864
  return model
865
 
866
 
 
871
  weights ported from official Google JAX impl:
872
  Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ Password: qqt8
873
  """
874
+ model = VisionTransformer(
875
+ img_size=224,
876
+ patch_size=16,
877
+ embed_dim=1024,
878
+ depth=24,
879
+ num_heads=16,
880
+ representation_size=None,
881
+ num_classes=num_classes,
882
+ )
883
  return model
884
 
885
 
 
890
  weights ported from official Google JAX impl:
891
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
892
  """
893
+ model = VisionTransformer(
894
+ img_size=224,
895
+ patch_size=16,
896
+ embed_dim=1024,
897
+ depth=24,
898
+ num_heads=16,
899
+ representation_size=1024 if has_logits else None,
900
+ num_classes=num_classes,
901
+ )
902
  return model
903
 
904
 
 
909
  weights ported from official Google JAX impl:
910
  https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
911
  """
912
+ model = VisionTransformer(
913
+ img_size=224,
914
+ patch_size=32,
915
+ embed_dim=1024,
916
+ depth=24,
917
+ num_heads=16,
918
+ representation_size=1024 if has_logits else None,
919
+ num_classes=num_classes,
920
+ )
921
  return model
922
 
923
 
 
927
  ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
928
  NOTE: converted weights not currently available, too large for github release hosting.
929
  """
930
+ model = VisionTransformer(
931
+ img_size=224,
932
+ patch_size=14,
933
+ embed_dim=1280,
934
+ depth=32,
935
+ num_heads=16,
936
+ representation_size=1280 if has_logits else None,
937
+ num_classes=num_classes,
938
+ )
939
  return model
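Most of the changes to this file are formatting; the substantive one is the relocation of the timm imports near the top of the updated file above. Newer timm releases expose these helpers from timm.layers and timm.models, while older releases only provide the timm.models.layers and timm.models.registry paths. A version-tolerant variant is sketched below as a suggestion; the try/except fallback is not part of this commit:

try:
    # newer import paths, as used in the updated file above
    from timm.layers import DropPath, to_2tuple, trunc_normal_
    from timm.models import register_model
except ImportError:
    # fall back to the older paths that this commit replaces
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
    from timm.models.registry import register_model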
FER/prediction.py CHANGED
@@ -48,7 +48,7 @@ def main():
48
  )
49
  )
50
  else:
51
- print("=> no checkpoint found at '{}'".format(model_path))
52
  predict(model, image_path=image_arr)
53
  return
54
 
 
48
  )
49
  )
50
  else:
51
+ print("[!] prediction.py => no checkpoint found at '{}'".format(model_path))
52
  predict(model, image_path=image_arr)
53
  return
54
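The prediction.py change only prefixes the existing warning with the emitting script, which makes the message easier to trace in mixed logs. For context, a minimal sketch of the guard pattern that warning belongs to; the function name, checkpoint key, and fallback behaviour here are illustrative assumptions rather than the exact code in prediction.py:

import os
import torch

def load_checkpoint_if_present(model, model_path, device="cpu"):
    # Load weights when a checkpoint file exists; otherwise warn and carry on
    # with the model as-is, so prediction can still run.
    if model_path and os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])  # assumed checkpoint key
    else:
        print("[!] prediction.py => no checkpoint found at '{}'".format(model_path))
    return model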