Spaces: Build error

jwyang committed
Commit ad7aaa6 • 1 parent: a574e10
add heatmap visualization

Files changed:
- app.py +16 -4
- model/image_encoder/swin_transformer.py +8 -4
- model/model.py +15 -4
app.py
CHANGED
@@ -118,11 +118,20 @@ def recognize_image(image, texts):
     text_embeddings = model.get_text_embeddings(texts.split(';'))

     # compute output
-    feat_img = model.encode_image(img_t.unsqueeze(0))
+    feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
     output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
     prediction = output.softmax(-1).flatten()

-    return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
+    # generate feat map given the top matched texts
+    output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
+    output_map = output_map.view(1, 1, 7, 7)
+
+    output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
+    output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
+    output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())
+    heatmap = show_cam_on_image(img_d, output_map, use_rgb=True)
+
+    return Image.fromarray(heatmap), {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}


 image = gr.inputs.Image()
@@ -132,8 +141,11 @@ gr.Interface(
     description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
     fn=recognize_image,
     inputs=["image", "text"],
-    outputs=[
-        label,
+    outputs=[
+        gr.outputs.Image(
+            type="pil",
+            label="zero-shot heat map"),
+        label,
     ],
     examples=[
         ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
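Taken together, the new lines in recognize_image dot every spatial token of the image feature map against the best-matching text embedding, softmax the result over the token grid, upsample it to the input resolution, and blend it over the input image. Below is a minimal standalone sketch of that computation on dummy tensors; the concrete shapes (D = 512 embedding dim, 7x7 = 49 tokens for a 224x224 input) and the pytorch-grad-cam origin of show_cam_on_image are assumptions, not something the diff itself states.

import numpy as np
import torch
import torch.nn as nn

# Dummy stand-ins for the real model outputs (assumed shapes):
#   feat_map: B x D x L  -- per-token image embeddings, L = 7*7 tokens
#   text_embeddings: K x D -- one embedding per prompt
B, D, L, K = 1, 512, 49, 3
feat_map = torch.randn(B, D, L)
text_embeddings = torch.randn(K, D)
prediction = torch.tensor([0.1, 0.7, 0.2])  # softmax scores over the K prompts

# Dot each spatial token with the top-matching text embedding, then softmax
# over the 49 positions, mirroring the new app.py code.
top_text = text_embeddings[prediction.argmax()]          # D
output_map = (feat_map * top_text.unsqueeze(-1)).sum(1)  # B x L
output_map = output_map.softmax(-1).view(1, 1, 7, 7)

# Upsample the 7x7 grid to the input resolution and min-max normalize to [0, 1].
output_map = nn.Upsample(size=(224, 224), mode='bilinear')(output_map)
output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())

# show_cam_on_image(img_d, output_map, use_rgb=True) would blend this mask with
# the input image (img_d as a float array in [0, 1]); here we just check shapes.
print(output_map.shape)  # (224, 224, 1)

On the interface side, recognize_image now returns a (heatmap, scores) pair, which is why the outputs list grows a gr.outputs.Image(type="pil") entry ahead of the existing label output.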
model/image_encoder/swin_transformer.py
CHANGED
@@ -557,7 +557,7 @@ class SwinTransformer(nn.Module):
     def no_weight_decay_keywords(self):
         return {'relative_position_bias_table'}

-    def forward_features(self, x):
+    def forward_features(self, x, output_map=False):
         x = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
@@ -566,10 +566,14 @@ class SwinTransformer(nn.Module):
         for layer in self.layers:
             x = layer(x)

-        x = self.norm(x)  # B L C
-        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x_map = self.norm(x).transpose(1, 2)  # B C L
+        x = self.avgpool(x_map)  # B C 1
         x = torch.flatten(x, 1)
-        return x
+
+        if output_map:
+            return x, x_map
+        else:
+            return x

     def forward(self, x):
         x = self.forward_features(x)
model/model.py
CHANGED
@@ -153,14 +153,25 @@ class UniCLModel(nn.Module):
         imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
         return imnet_text_embeddings

-    def encode_image(self, image, norm=True):
-        x = self.image_encoder.forward_features(image)
+    def encode_image(self, image, norm=True, output_map=False):
+        x = self.image_encoder.forward_features(image, output_map=output_map)
+        if output_map:
+            x, x_map = x
+
         x = x @ self.image_projection

+        if output_map:
+            x_map = self.image_projection.unsqueeze(0).transpose(1, 2) @ x_map
+
         if norm:
             x = x / x.norm(dim=-1, keepdim=True)
-
-        return x
+            if output_map:
+                x_map = x_map / x_map.norm(dim=1, keepdim=True)
+
+        if output_map:
+            return x, x_map
+        else:
+            return x

     def encode_text(self, text, norm=True):
         x = self.text_encoder(**text)