jwyang committed
Commit: eb1d5d5
Parent: 8424dda

support arbitrary size

Files changed:
- app.py +3 -3
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/model.cpython-39.pyc +0 -0
- model/__pycache__/templates.cpython-39.pyc +0 -0
- model/image_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
- model/image_encoder/__pycache__/build.cpython-39.pyc +0 -0
- model/image_encoder/__pycache__/focalnet.cpython-39.pyc +0 -0
- model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc +0 -0
- model/image_encoder/swin_transformer.py +71 -25
- model/model.py +2 -2
- model/text_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
- model/text_encoder/__pycache__/build.cpython-39.pyc +0 -0
- model/text_encoder/__pycache__/hf_model.cpython-39.pyc +0 -0
- model/text_encoder/__pycache__/registry.cpython-39.pyc +0 -0
- model/text_encoder/__pycache__/transformer.cpython-39.pyc +0 -0
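
The core of this commit: the Swin image encoder no longer assumes the fixed `input_resolution`. The patch grid `(Ph, Pw)` is threaded through every stage, and feature maps are padded up to multiples of the window size before window attention. A minimal sketch of that padding step, reusing the same formulas that appear in the swin_transformer.py diff below (the helper name `pad_to_window_multiple` and the tensor sizes are illustrative, not part of the repo):

import torch
import torch.nn.functional as F

def pad_to_window_multiple(x, window_size):
    """Pad a B x Ph x Pw x C feature map so both spatial dims are multiples of window_size."""
    _, Ph, Pw, _ = x.shape
    pad_r = (window_size - Pw % window_size) % window_size   # extra columns on the right
    pad_b = (window_size - Ph % window_size) % window_size   # extra rows on the bottom
    # F.pad pads the trailing dims first: (C_left, C_right, W_left, W_right, H_top, H_bottom)
    return F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

x = torch.randn(1, 50, 67, 96)      # arbitrary 50x67 patch grid
y = pad_to_window_multiple(x, 7)    # window_size=7 as in Swin-T
print(y.shape)                      # torch.Size([1, 56, 70, 96])

The double modulo keeps the pad at zero when the grid is already divisible by the window size, so the original 224x224 case behaves exactly as before.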
app.py
CHANGED
@@ -118,13 +118,13 @@ def recognize_image(image, texts):
     text_embeddings = model.get_text_embeddings(texts.split(';'))
 
     # compute output
-    feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
+    feat_img, feat_map, H, W = model.encode_image(img_t.unsqueeze(0), output_map=True)
     output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
     prediction = output.softmax(-1).flatten()
 
     # generate feat map given the top matched texts
     output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
-    output_map = output_map.view(1, 1,
+    output_map = output_map.view(1, 1, H, W)
 
     output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
     output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
@@ -142,10 +142,10 @@ gr.Interface(
     fn=recognize_image,
     inputs=["image", "text"],
     outputs=[
-        label,
         gr.outputs.Image(
             type="pil",
             label="zero-shot heat map"),
+        label
     ],
     examples=[
         ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
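As a sanity check of the app-side change, here is a self-contained sketch of the heat-map path with dummy tensors: `feat_map` comes back as B x C x (H*W), so the per-text response is reshaped with the `H, W` now returned by `encode_image` and upsampled to the input size. All names and sizes below are illustrative stand-ins, not values read from the Space:

import torch
import torch.nn as nn

C, H, W = 512, 15, 20                 # e.g. a 480x640 input at total stride 32
img_t = torch.randn(3, 480, 640)      # CHW image tensor, as in app.py
feat_map = torch.randn(1, C, H * W)   # B x C x L map from encode_image(..., output_map=True)
text_embedding = torch.randn(C)       # embedding of the best-matching text

# response of every patch to the chosen text, normalized over patches
output_map = (feat_map * text_embedding.unsqueeze(-1)).sum(1).softmax(-1)
output_map = output_map.view(1, 1, H, W)   # the line this commit fixes: H, W are dynamic now

# upsample back to the input resolution, as app.py does
output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
heat = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
print(heat.shape)                     # (480, 640, 1)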
model/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/model/__pycache__/__init__.cpython-39.pyc and b/model/__pycache__/__init__.cpython-39.pyc differ

model/__pycache__/model.cpython-39.pyc
CHANGED
Binary files a/model/__pycache__/model.cpython-39.pyc and b/model/__pycache__/model.cpython-39.pyc differ

model/__pycache__/templates.cpython-39.pyc
CHANGED
Binary files a/model/__pycache__/templates.cpython-39.pyc and b/model/__pycache__/templates.cpython-39.pyc differ

model/image_encoder/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/model/image_encoder/__pycache__/__init__.cpython-39.pyc and b/model/image_encoder/__pycache__/__init__.cpython-39.pyc differ

model/image_encoder/__pycache__/build.cpython-39.pyc
CHANGED
Binary files a/model/image_encoder/__pycache__/build.cpython-39.pyc and b/model/image_encoder/__pycache__/build.cpython-39.pyc differ

model/image_encoder/__pycache__/focalnet.cpython-39.pyc
CHANGED
Binary files a/model/image_encoder/__pycache__/focalnet.cpython-39.pyc and b/model/image_encoder/__pycache__/focalnet.cpython-39.pyc differ

model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc
CHANGED
Binary files a/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc and b/model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc differ
model/image_encoder/swin_transformer.py
CHANGED
@@ -4,9 +4,10 @@
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Ze Liu
 # --------------------------------------------------------
-
+import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
@@ -230,38 +231,51 @@ class SwinTransformerBlock(nn.Module):
 
         self.register_buffer("attn_mask", attn_mask)
 
-    def forward(self, x):
-        H, W = self.input_resolution
+    def forward(self, x, Ph, Pw, attn_mask):
+        # H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        # assert L == H * W, "input feature has wrong size"
 
         shortcut = x
         x = self.norm1(x)
-        x = x.view(B, H, W, C)
+        x = x.view(B, Ph, Pw, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - Pw % self.window_size) % self.window_size
+        pad_b = (self.window_size - Ph % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
 
         # cyclic shift
         if self.shift_size > 0:
             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = attn_mask
         else:
             shifted_x = x
+            attn_mask = None
 
         # partition windows
         x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
 
         # reverse cyclic shift
         if self.shift_size > 0:
             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
         else:
             x = shifted_x
-        x = x.view(B, H * W, C)
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :Ph, :Pw, :].contiguous()
+
+        x = x.view(B, Ph * Pw, C)
 
         # FFN
         x = shortcut + self.drop_path(x)
@@ -304,16 +318,20 @@ class PatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
         self.norm = norm_layer(4 * dim)
 
-    def forward(self, x):
+    def forward(self, x, Ph, Pw):
         """
         x: B, H*W, C
         """
-        H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+        # assert L == H * W, "input feature has wrong size"
+        # assert Ph % 2 == 0 and Pw % 2 == 0, f"x size ({Ph}*{Pw}) are not even."
 
-        x = x.view(B, H, W, C)
+        x = x.view(B, Ph, Pw, C)
+
+        # padding
+        pad_input = (Ph % 2 == 1) or (Pw % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, Pw % 2, 0, Ph % 2))
 
         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
@@ -366,6 +384,8 @@ class BasicLayer(nn.Module):
         self.input_resolution = input_resolution
         self.depth = depth
         self.use_checkpoint = use_checkpoint
+        self.window_size = window_size
+        self.shift_size = window_size // 2
 
         # build blocks
         self.blocks = nn.ModuleList([
@@ -385,15 +405,39 @@
         else:
             self.downsample = None
 
-    def forward(self, x):
+    def forward(self, x, Ph, Pw):
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(Ph / self.window_size)) * self.window_size
+        Wp = int(np.ceil(Pw / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+
         for blk in self.blocks:
             if self.use_checkpoint:
                 x = checkpoint.checkpoint(blk, x)
             else:
-                x = blk(x)
+                x = blk(x, Ph, Pw, attn_mask)
         if self.downsample is not None:
-            x = self.downsample(x)
-        return x
+            x = self.downsample(x, Ph, Pw)
+            Ph, Pw = (Ph + 1) // 2, (Pw + 1) // 2
+        return x, Ph, Pw
 
     def extra_repr(self) -> str:
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
@@ -440,12 +484,14 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Ph, Pw = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
-        return x
+        return x, Ph, Pw
 
     def flops(self):
         Ho, Wo = self.patches_resolution
@@ -558,20 +604,20 @@ class SwinTransformer(nn.Module):
         return {'relative_position_bias_table'}
 
     def forward_features(self, x, output_map=False):
-        x = self.patch_embed(x)
+        x, Ph, Pw = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
         x = self.pos_drop(x)
 
         for layer in self.layers:
-            x = layer(x)
+            x, Ph, Pw = layer(x, Ph, Pw)
 
         x_map = self.norm(x).transpose(1, 2)  # B C L
         x = self.avgpool(x_map)  # B C 1
         x = torch.flatten(x, 1)
 
         if output_map:
-            return x, x_map
+            return x, x_map, Ph, Pw
         else:
             return x
 
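The trickiest part of the swin_transformer.py change is the shifted-window attention mask, which is now rebuilt on every forward pass from the padded grid (Hp, Wp) instead of living in a fixed `attn_mask` buffer. Below is a standalone sketch of that construction; `window_partition` is copied from the standard Swin implementation for self-containment, and `build_sw_msa_mask` is an illustrative name, not a function in this repo:

import numpy as np
import torch

def window_partition(x, window_size):
    # B x H x W x C -> (num_windows*B) x window_size x window_size x C
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def build_sw_msa_mask(Ph, Pw, window_size, shift_size):
    # pad the Ph x Pw grid up to multiples of the window size, as BasicLayer.forward does
    Hp = int(np.ceil(Ph / window_size)) * window_size
    Wp = int(np.ceil(Pw / window_size)) * window_size
    img_mask = torch.zeros((1, Hp, Wp, 1))
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt   # label each shifted region with its own id
            cnt += 1
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # nonzero entries mark token pairs from different regions inside one window
    return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

mask = build_sw_msa_mask(Ph=50, Pw=67, window_size=7, shift_size=3)
print(mask.shape)   # torch.Size([80, 49, 49]): 8 x 10 windows on the padded 56 x 70 grid

Windows that straddle the cyclic shift get -100 on cross-region token pairs, which drives their attention weights to roughly zero after the softmax, matching the masked_fill lines in the diff above.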
model/model.py
CHANGED
@@ -156,7 +156,7 @@ class UniCLModel(nn.Module):
     def encode_image(self, image, norm=True, output_map=False):
         x = self.image_encoder.forward_features(image, output_map=output_map)
         if output_map:
-            x, x_map = x
+            x, x_map, H, W = x
 
         x = x @ self.image_projection
 
@@ -169,7 +169,7 @@ class UniCLModel(nn.Module):
             x_map = x_map / x_map.norm(dim=1, keepdim=True)
 
         if output_map:
-            return x, x_map
+            return x, x_map, H, W
         else:
             return x
 
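With this change, `encode_image(..., output_map=True)` returns a 4-tuple, and the spatial size of `x_map` tracks the actual input rather than a fixed 224-based grid. A back-of-the-envelope sketch of how the final (H, W) comes about, assuming a stride-4 patch embed and three PatchMerging stages (standard Swin-T values, assumed here rather than read from this diff):

# Final feature-map size for an arbitrary input, following the diff's bookkeeping:
# PatchEmbed returns Ph, Pw, and each downsampling layer does Ph, Pw = (Ph + 1) // 2, (Pw + 1) // 2.
def final_map_size(height, width, patch_size=4, num_downsamples=3):
    Ph, Pw = height // patch_size, width // patch_size   # patch grid from the stride-4 conv
    for _ in range(num_downsamples):
        Ph, Pw = (Ph + 1) // 2, (Pw + 1) // 2            # ceil-division, matching the padding in PatchMerging
    return Ph, Pw

print(final_map_size(224, 224))   # (7, 7)  -- the old fixed case
print(final_map_size(480, 640))   # (15, 20)
print(final_map_size(500, 375))   # (16, 12)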
model/text_encoder/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/model/text_encoder/__pycache__/__init__.cpython-39.pyc and b/model/text_encoder/__pycache__/__init__.cpython-39.pyc differ

model/text_encoder/__pycache__/build.cpython-39.pyc
CHANGED
Binary files a/model/text_encoder/__pycache__/build.cpython-39.pyc and b/model/text_encoder/__pycache__/build.cpython-39.pyc differ

model/text_encoder/__pycache__/hf_model.cpython-39.pyc
CHANGED
Binary files a/model/text_encoder/__pycache__/hf_model.cpython-39.pyc and b/model/text_encoder/__pycache__/hf_model.cpython-39.pyc differ

model/text_encoder/__pycache__/registry.cpython-39.pyc
CHANGED
Binary files a/model/text_encoder/__pycache__/registry.cpython-39.pyc and b/model/text_encoder/__pycache__/registry.cpython-39.pyc differ

model/text_encoder/__pycache__/transformer.cpython-39.pyc
CHANGED
Binary files a/model/text_encoder/__pycache__/transformer.cpython-39.pyc and b/model/text_encoder/__pycache__/transformer.cpython-39.pyc differ