Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,7 @@ class AttentionExtractor:
|
|
24 |
self.attention_maps[module.full_name] = output
|
25 |
|
26 |
for name, module in self.model.named_modules():
|
|
|
27 |
if name.lower().endswith('.attn_drop'):
|
28 |
module.full_name = name
|
29 |
print('hooking', name)
|
@@ -34,8 +35,9 @@ class AttentionExtractor:
|
|
34 |
|
35 |
def get_attention_models() -> List[str]:
|
36 |
"""Get a list of timm models that have attention blocks."""
|
37 |
-
all_models = timm.
|
38 |
-
|
|
|
39 |
return attention_models
|
40 |
|
41 |
def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
|
@@ -88,7 +90,8 @@ def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image
|
|
88 |
"""Visualize attention maps for the given image and model."""
|
89 |
model, extractor = load_model(model_name)
|
90 |
attention_maps = process_image(image, model, extractor)
|
91 |
-
|
|
|
92 |
# Convert PIL Image to numpy array
|
93 |
image_np = np.array(image)
|
94 |
|
@@ -97,8 +100,8 @@ def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image
|
|
97 |
for layer_name, attn_map in attention_maps.items():
|
98 |
print(f"Attention map shape for {layer_name}: {attn_map.shape}")
|
99 |
|
100 |
-
# Remove the CLS token attention and average over heads
|
101 |
-
attn_map = attn_map[0, :, 0,
|
102 |
|
103 |
# Reshape the attention map to 2D
|
104 |
num_patches = int(np.sqrt(attn_map.shape[0]))
|
|
|
24 |
self.attention_maps[module.full_name] = output
|
25 |
|
26 |
for name, module in self.model.named_modules():
|
27 |
+
# FIXME need to make more generic outside of vit
|
28 |
if name.lower().endswith('.attn_drop'):
|
29 |
module.full_name = name
|
30 |
print('hooking', name)
|
|
|
35 |
|
36 |
def get_attention_models() -> List[str]:
|
37 |
"""Get a list of timm models that have attention blocks."""
|
38 |
+
all_models = timm.list_pretrained()
|
39 |
+
# FIXME Focusing on ViT models for initial impl
|
40 |
+
attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')]
|
41 |
return attention_models
|
42 |
|
43 |
def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
|
|
|
90 |
"""Visualize attention maps for the given image and model."""
|
91 |
model, extractor = load_model(model_name)
|
92 |
attention_maps = process_image(image, model, extractor)
|
93 |
+
num_prefix_tokens = getattr(model, 'num_prefix_tokens', 0)
|
94 |
+
|
95 |
# Convert PIL Image to numpy array
|
96 |
image_np = np.array(image)
|
97 |
|
|
|
100 |
for layer_name, attn_map in attention_maps.items():
|
101 |
print(f"Attention map shape for {layer_name}: {attn_map.shape}")
|
102 |
|
103 |
+
# Remove the CLS token attention and average over heads
|
104 |
+
attn_map = attn_map[0, :, 0, num_prefix_tokens:].mean(0) # Shape: (seq_len-1,)
|
105 |
|
106 |
# Reshape the attention map to 2D
|
107 |
num_patches = int(np.sqrt(attn_map.shape[0]))
|