jwyang commited on
Commit
eb1d5d5
1 Parent(s): 8424dda

support arbitary size

Browse files
app.py CHANGED
@@ -118,13 +118,13 @@ def recognize_image(image, texts):
118
  text_embeddings = model.get_text_embeddings(texts.split(';'))
119
 
120
  # compute output
121
- feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
122
  output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
123
  prediction = output.softmax(-1).flatten()
124
 
125
  # generate feat map given the top matched texts
126
  output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
127
- output_map = output_map.view(1, 1, 7, 7)
128
 
129
  output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
130
  output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
@@ -142,10 +142,10 @@ gr.Interface(
142
  fn=recognize_image,
143
  inputs=["image", "text"],
144
  outputs=[
145
- label,
146
  gr.outputs.Image(
147
  type="pil",
148
  label="zero-shot heat map"),
 
149
  ],
150
  examples=[
151
  ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
 
118
  text_embeddings = model.get_text_embeddings(texts.split(';'))
119
 
120
  # compute output
121
+ feat_img, feat_map, H, W = model.encode_image(img_t.unsqueeze(0), output_map=True)
122
  output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
123
  prediction = output.softmax(-1).flatten()
124
 
125
  # generate feat map given the top matched texts
126
  output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
127
+ output_map = output_map.view(1, 1, H, W)
128
 
129
  output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
130
  output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
 
142
  fn=recognize_image,
143
  inputs=["image", "text"],
144
  outputs=[
 
145
  gr.outputs.Image(
146
  type="pil",
147
  label="zero-shot heat map"),
148
+ label
149
  ],
150
  examples=[
151
  ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
model/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/__init__.cpython-39.pyc and b/model/__pycache__/__init__.cpython-39.pyc differ
 
model/__pycache__/model.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/model.cpython-39.pyc and b/model/__pycache__/model.cpython-39.pyc differ
 
model/__pycache__/templates.cpython-39.pyc CHANGED
Binary files a/model/__pycache__/templates.cpython-39.pyc and b/model/__pycache__/templates.cpython-39.pyc differ
 
model/image_encoder/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/__init__.cpython-39.pyc and b/model/image_encoder/__pycache__/__init__.cpython-39.pyc differ
 
model/image_encoder/__pycache__/build.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/build.cpython-39.pyc and b/model/image_encoder/__pycache__/build.cpython-39.pyc differ
 
model/image_encoder/__pycache__/focalnet.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/focalnet.cpython-39.pyc and b/model/image_encoder/__pycache__/focalnet.cpython-39.pyc differ
 
model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc CHANGED
Binary files a/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc and b/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc differ
 
model/image_encoder/swin_transformer.py CHANGED
@@ -4,9 +4,10 @@
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # Written by Ze Liu
6
  # --------------------------------------------------------
7
-
8
  import torch
9
  import torch.nn as nn
 
10
  import torch.utils.checkpoint as checkpoint
11
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
 
@@ -230,38 +231,51 @@ class SwinTransformerBlock(nn.Module):
230
 
231
  self.register_buffer("attn_mask", attn_mask)
232
 
233
- def forward(self, x):
234
- H, W = self.input_resolution
235
  B, L, C = x.shape
236
- assert L == H * W, "input feature has wrong size"
237
 
238
  shortcut = x
239
  x = self.norm1(x)
240
- x = x.view(B, H, W, C)
 
 
 
 
 
 
 
241
 
242
  # cyclic shift
243
  if self.shift_size > 0:
244
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
 
245
  else:
246
  shifted_x = x
 
247
 
248
  # partition windows
249
  x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
250
  x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
251
 
252
  # W-MSA/SW-MSA
253
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
254
 
255
  # merge windows
256
  attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
257
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
258
 
259
  # reverse cyclic shift
260
  if self.shift_size > 0:
261
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
262
  else:
263
  x = shifted_x
264
- x = x.view(B, H * W, C)
 
 
 
 
265
 
266
  # FFN
267
  x = shortcut + self.drop_path(x)
@@ -304,16 +318,20 @@ class PatchMerging(nn.Module):
304
  self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
305
  self.norm = norm_layer(4 * dim)
306
 
307
- def forward(self, x):
308
  """
309
  x: B, H*W, C
310
  """
311
- H, W = self.input_resolution
312
  B, L, C = x.shape
313
- assert L == H * W, "input feature has wrong size"
314
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
315
 
316
- x = x.view(B, H, W, C)
 
 
 
 
 
317
 
318
  x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
319
  x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
@@ -366,6 +384,8 @@ class BasicLayer(nn.Module):
366
  self.input_resolution = input_resolution
367
  self.depth = depth
368
  self.use_checkpoint = use_checkpoint
 
 
369
 
370
  # build blocks
371
  self.blocks = nn.ModuleList([
@@ -385,15 +405,39 @@ class BasicLayer(nn.Module):
385
  else:
386
  self.downsample = None
387
 
388
- def forward(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  for blk in self.blocks:
390
  if self.use_checkpoint:
391
  x = checkpoint.checkpoint(blk, x)
392
  else:
393
- x = blk(x)
394
  if self.downsample is not None:
395
- x = self.downsample(x)
396
- return x
 
397
 
398
  def extra_repr(self) -> str:
399
  return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
@@ -440,12 +484,14 @@ class PatchEmbed(nn.Module):
440
  def forward(self, x):
441
  B, C, H, W = x.shape
442
  # FIXME look at relaxing size constraints
443
- assert H == self.img_size[0] and W == self.img_size[1], \
444
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
445
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
 
 
446
  if self.norm is not None:
447
  x = self.norm(x)
448
- return x
449
 
450
  def flops(self):
451
  Ho, Wo = self.patches_resolution
@@ -558,20 +604,20 @@ class SwinTransformer(nn.Module):
558
  return {'relative_position_bias_table'}
559
 
560
  def forward_features(self, x, output_map=False):
561
- x = self.patch_embed(x)
562
  if self.ape:
563
  x = x + self.absolute_pos_embed
564
  x = self.pos_drop(x)
565
 
566
  for layer in self.layers:
567
- x = layer(x)
568
 
569
  x_map = self.norm(x).transpose(1, 2) # B C L
570
  x = self.avgpool(x_map) # B C 1
571
  x = torch.flatten(x, 1)
572
 
573
  if output_map:
574
- return x, x_map
575
  else:
576
  return x
577
 
 
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # Written by Ze Liu
6
  # --------------------------------------------------------
7
+ import numpy as np
8
  import torch
9
  import torch.nn as nn
10
+ import torch.nn.functional as F
11
  import torch.utils.checkpoint as checkpoint
12
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
 
 
231
 
232
  self.register_buffer("attn_mask", attn_mask)
233
 
234
+ def forward(self, x, Ph, Pw, attn_mask):
235
+ # H, W = self.input_resolution
236
  B, L, C = x.shape
237
+ # assert L == H * W, "input feature has wrong size"
238
 
239
  shortcut = x
240
  x = self.norm1(x)
241
+ x = x.view(B, Ph, Pw, C)
242
+
243
+ # pad feature maps to multiples of window size
244
+ pad_l = pad_t = 0
245
+ pad_r = (self.window_size - Pw % self.window_size) % self.window_size
246
+ pad_b = (self.window_size - Ph % self.window_size) % self.window_size
247
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
248
+ _, Hp, Wp, _ = x.shape
249
 
250
  # cyclic shift
251
  if self.shift_size > 0:
252
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
253
+ attn_mask = attn_mask
254
  else:
255
  shifted_x = x
256
+ attn_mask = None
257
 
258
  # partition windows
259
  x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
260
  x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
261
 
262
  # W-MSA/SW-MSA
263
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
264
 
265
  # merge windows
266
  attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
267
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
268
 
269
  # reverse cyclic shift
270
  if self.shift_size > 0:
271
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
272
  else:
273
  x = shifted_x
274
+
275
+ if pad_r > 0 or pad_b > 0:
276
+ x = x[:, :Ph, :Pw, :].contiguous()
277
+
278
+ x = x.view(B, Ph * Pw, C)
279
 
280
  # FFN
281
  x = shortcut + self.drop_path(x)
 
318
  self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
319
  self.norm = norm_layer(4 * dim)
320
 
321
+ def forward(self, x, Ph, Pw):
322
  """
323
  x: B, H*W, C
324
  """
 
325
  B, L, C = x.shape
326
+ # assert L == H * W, "input feature has wrong size"
327
+ # assert Ph % 2 == 0 and Pw % 2 == 0, f"x size ({Ph}*{Pw}) are not even."
328
 
329
+ x = x.view(B, Ph, Pw, C)
330
+
331
+ # padding
332
+ pad_input = (Ph % 2 == 1) or (Pw % 2 == 1)
333
+ if pad_input:
334
+ x = F.pad(x, (0, 0, 0, Pw % 2, 0, Ph % 2))
335
 
336
  x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
337
  x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
 
384
  self.input_resolution = input_resolution
385
  self.depth = depth
386
  self.use_checkpoint = use_checkpoint
387
+ self.window_size = window_size
388
+ self.shift_size = window_size // 2
389
 
390
  # build blocks
391
  self.blocks = nn.ModuleList([
 
405
  else:
406
  self.downsample = None
407
 
408
+ def forward(self, x, Ph, Pw):
409
+
410
+ # calculate attention mask for SW-MSA
411
+ Hp = int(np.ceil(Ph / self.window_size)) * self.window_size
412
+ Wp = int(np.ceil(Pw / self.window_size)) * self.window_size
413
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
414
+ h_slices = (slice(0, -self.window_size),
415
+ slice(-self.window_size, -self.shift_size),
416
+ slice(-self.shift_size, None))
417
+ w_slices = (slice(0, -self.window_size),
418
+ slice(-self.window_size, -self.shift_size),
419
+ slice(-self.shift_size, None))
420
+ cnt = 0
421
+ for h in h_slices:
422
+ for w in w_slices:
423
+ img_mask[:, h, w, :] = cnt
424
+ cnt += 1
425
+
426
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
427
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
428
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
429
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
430
+
431
+
432
  for blk in self.blocks:
433
  if self.use_checkpoint:
434
  x = checkpoint.checkpoint(blk, x)
435
  else:
436
+ x = blk(x, Ph, Pw, attn_mask)
437
  if self.downsample is not None:
438
+ x = self.downsample(x, Ph, Pw)
439
+ Ph, Pw = (Ph + 1) // 2, (Pw + 1) // 2
440
+ return x, Ph, Pw
441
 
442
  def extra_repr(self) -> str:
443
  return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
484
  def forward(self, x):
485
  B, C, H, W = x.shape
486
  # FIXME look at relaxing size constraints
487
+ # assert H == self.img_size[0] and W == self.img_size[1], \
488
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
489
+ x = self.proj(x)
490
+ Ph, Pw = x.shape[2:]
491
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
492
  if self.norm is not None:
493
  x = self.norm(x)
494
+ return x, Ph, Pw
495
 
496
  def flops(self):
497
  Ho, Wo = self.patches_resolution
 
604
  return {'relative_position_bias_table'}
605
 
606
  def forward_features(self, x, output_map=False):
607
+ x, Ph, Pw = self.patch_embed(x)
608
  if self.ape:
609
  x = x + self.absolute_pos_embed
610
  x = self.pos_drop(x)
611
 
612
  for layer in self.layers:
613
+ x, Ph, Pw = layer(x, Ph, Pw)
614
 
615
  x_map = self.norm(x).transpose(1, 2) # B C L
616
  x = self.avgpool(x_map) # B C 1
617
  x = torch.flatten(x, 1)
618
 
619
  if output_map:
620
+ return x, x_map, Ph, Pw
621
  else:
622
  return x
623
 
model/model.py CHANGED
@@ -156,7 +156,7 @@ class UniCLModel(nn.Module):
156
  def encode_image(self, image, norm=True, output_map=False):
157
  x = self.image_encoder.forward_features(image, output_map=output_map)
158
  if output_map:
159
- x, x_map = x
160
 
161
  x = x @ self.image_projection
162
 
@@ -169,7 +169,7 @@ class UniCLModel(nn.Module):
169
  x_map = x_map / x_map.norm(dim=1, keepdim=True)
170
 
171
  if output_map:
172
- return x, x_map
173
  else:
174
  return x
175
 
 
156
  def encode_image(self, image, norm=True, output_map=False):
157
  x = self.image_encoder.forward_features(image, output_map=output_map)
158
  if output_map:
159
+ x, x_map, H, W = x
160
 
161
  x = x @ self.image_projection
162
 
 
169
  x_map = x_map / x_map.norm(dim=1, keepdim=True)
170
 
171
  if output_map:
172
+ return x, x_map, H, W
173
  else:
174
  return x
175
 
model/text_encoder/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/__init__.cpython-39.pyc and b/model/text_encoder/__pycache__/__init__.cpython-39.pyc differ
 
model/text_encoder/__pycache__/build.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/build.cpython-39.pyc and b/model/text_encoder/__pycache__/build.cpython-39.pyc differ
 
model/text_encoder/__pycache__/hf_model.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/hf_model.cpython-39.pyc and b/model/text_encoder/__pycache__/hf_model.cpython-39.pyc differ
 
model/text_encoder/__pycache__/registry.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/registry.cpython-39.pyc and b/model/text_encoder/__pycache__/registry.cpython-39.pyc differ
 
model/text_encoder/__pycache__/transformer.cpython-39.pyc CHANGED
Binary files a/model/text_encoder/__pycache__/transformer.cpython-39.pyc and b/model/text_encoder/__pycache__/transformer.cpython-39.pyc differ