Srikumar26 committed
Commit
792e911
1 Parent(s): fdf4051

Upload model.py with huggingface_hub

Files changed (1)
model.py +388 -0
model.py ADDED
@@ -0,0 +1,388 @@
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed, Block
from huggingface_hub import PyTorchModelHubMixin


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
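    # frequencies follow the standard Transformer sinusoid schedule: 1/10000^(2i/D)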
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


################################################################################
# Upsample Block Modules
################################################################################
class LayerNorm(nn.Module):
    """
    LayerNorm supporting both channels_last (B, H, W, C) and
    channels_first (B, C, H, W) inputs.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class ResidualBlock(torch.nn.Module):
    """
    Two-conv residual block used inside UpsampleBlock.
    """
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        x: tensor of shape (B, C, H, W)
        """
        residual = x
        out = self.relu(self.conv1(x))
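        # scale the residual branch by 0.5 before adding the skip connection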
        out = self.conv2(out) * 0.5
        out = out + residual

        return out


class UpsampleBlock(nn.Module):
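    """
    Upsamples 2x with a transposed convolution, refines with a residual block,
    and projects to out_channels; returns both the refined features and the
    projected output.
    """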
    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()

        self.up_conv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
        self.up_norm = LayerNorm(in_channels, eps=1e-6, data_format="channels_first")

        self.res_block = ResidualBlock(in_channels)
        self.res_norm = LayerNorm(in_channels, eps=1e-6, data_format="channels_first")

        self.proj_out = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # upsample 2x
        x = self.up_conv(x)
        x = self.up_norm(x)
        x = F.leaky_relu(x)

        # residual block
        x = self.res_block(x)
        x = self.res_norm(x)

        out = self.proj_out(x)

        return x, out

################################################################################

class MaskedAutoencoderViT(nn.Module, PyTorchModelHubMixin):
    """
    Masked Autoencoder with a VisionTransformer backbone, plus an auxiliary
    upsampling head that predicts a 2x-resolution reconstruction.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 proj_ratio=4):
        super().__init__()

        self.in_c = in_chans

        ######################################################
        # create upsample block layers
        ms_dim = self.in_c * proj_ratio
        self.proj_up_conv = nn.Conv2d(self.in_c, ms_dim, kernel_size=1, stride=1, padding=0)
        self.proj_up_norm = LayerNorm(ms_dim, eps=1e-6, data_format="channels_first")

        self.up_block = UpsampleBlock(ms_dim, self.in_c)

        ######################################################
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        ######################################################
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


    def patchify(self, imgs, p, c):
        """
        imgs: (N, C, H, W)
        p: patch size
        c: number of channels
        x: (N, L, p**2 * C)
        """
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x, p, c):
        """
        x: (N, L, p**2 * C)
        p: patch size
        c: number of channels
        imgs: (N, C, H, W)
        """
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
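        # (x holds 1 cls + len_keep tokens, so ids_restore.shape[1] + 1 - x.shape[1]
        # is the number of masked positions to fill in)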
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.patchify(imgs, self.patch_embed.patch_size[0], self.in_c)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward_multiscale(self, x):
        """
        x: (N, L, p*p*C)
        """
        x = self.unpatchify(x, self.patch_embed.patch_size[0], self.in_c)

        x = self.proj_up_conv(x)
        x = F.gelu(x)
        x = self.proj_up_norm(x)

        _, x = self.up_block(x)

        return x

    def forward(self, imgs, imgs_up, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
        pred_ms = self.forward_multiscale(pred)

        loss = self.forward_loss(imgs, pred, mask)  # MSE reconstruction loss
        ms_loss = F.l1_loss(pred_ms, imgs_up)  # multiscale loss (L1)

        return loss, ms_loss, pred, mask
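
A minimal usage sketch, assuming ViT-Base-style hyperparameters and a target
batch at twice the input resolution for the multiscale head:

    model = MaskedAutoencoderViT(
        img_size=224, patch_size=16, in_chans=3,
        embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    imgs = torch.randn(2, 3, 224, 224)     # input images
    imgs_up = torch.randn(2, 3, 448, 448)  # 2x-resolution reconstruction target
    loss, ms_loss, pred, mask = model(imgs, imgs_up, mask_ratio=0.75)
    # pred: (2, 196, 768) patch predictions; mask: (2, 196) binary mask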