rwightman (HF staff) committed
Commit 9571b87
Parent(s): e2a9fa8

Create app.py

Files changed (1): app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
+import gradio as gr
+import torch
+import timm
+import torch.nn.functional as F
+from timm.models import create_model
+from timm.data import create_transform
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import List, Tuple, Dict
+from collections import OrderedDict
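+# Note: matplotlib draws off-screen here; in a headless environment it falls
+# back to the Agg backend, which supports the canvas-to-PIL conversion below.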
+
+class AttentionExtractor:
+    """Capture attention maps by hooking each attention block's attn_drop module."""
+
+    def __init__(self, model: torch.nn.Module):
+        self.model = model
+        self.attention_maps = OrderedDict()
+        self._register_hooks()
+
+    def _register_hooks(self):
+        def hook_fn(module, input, output):
+            if isinstance(output, tuple):
+                self.attention_maps[module.full_name] = output[1]  # attention_probs
+            else:
+                self.attention_maps[module.full_name] = output
+
+        # Hook every module named '*.attn_drop'; in timm's non-fused attention
+        # path its output is the post-softmax attention matrix
+        for name, module in self.model.named_modules():
+            if name.lower().endswith('.attn_drop'):
+                module.full_name = name
+                print('hooking', name)
+                module.register_forward_hook(hook_fn)
+
+    def get_attention_maps(self) -> OrderedDict:
+        return self.attention_maps
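+# For reference: with vit_base_patch16_224 and a 224x224 input, each hooked
+# block yields a (1, 12, 197, 197) tensor after a forward pass, i.e. 12 heads
+# attending over 196 patch tokens plus one class token.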
+
+def get_attention_models() -> List[str]:
+    """Get a list of timm models that have attention blocks."""
+    # Focusing on ViT models for simplicity; pretrained=True filters out
+    # variants without released weights, which create_model(pretrained=True)
+    # below would otherwise fail on
+    attention_models = timm.list_models('*vit*', pretrained=True)
+    return attention_models
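+# timm matches this pattern with fnmatch, so a narrower filter such as
+# 'vit_base_patch16*' could be used to shorten the dropdown.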
+
+def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
+    """Load a model from timm and prepare it for attention extraction."""
+    # Disable fused (SDPA) attention so attention probabilities are computed
+    # explicitly and flow through attn_drop, where the hooks can see them
+    timm.layers.set_fused_attn(False)
+    model = create_model(model_name, pretrained=True)
+    model.eval()
+    extractor = AttentionExtractor(model)
+    return model, extractor
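+# Example: load_model('vit_base_patch16_224') prints 'hooking
+# blocks.0.attn.attn_drop' through 'hooking blocks.11.attn.attn_drop',
+# one line per transformer block.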
+
+def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
+    """Process the input image and get the attention maps."""
+    # Get the correct eval-time transform from the model's pretrained config
+    config = model.pretrained_cfg
+    transform = create_transform(
+        input_size=config['input_size'],
+        crop_pct=config['crop_pct'],
+        mean=config['mean'],
+        std=config['std'],
+        interpolation=config['interpolation'],
+        is_training=False
+    )
+
+    # Preprocess the image
+    tensor = transform(image).unsqueeze(0)
+
+    # Forward pass; the hooks populate extractor.attention_maps as a side effect
+    with torch.no_grad():
+        _ = model(tensor)
+
+    # Extract attention maps
+    attention_maps = extractor.get_attention_maps()
+
+    return attention_maps
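+# Note: the extractor's OrderedDict is keyed by module path (e.g.
+# 'blocks.0.attn.attn_drop') and is overwritten in place on each forward pass.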
+
+def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
+    # Broadcast the mask to three channels so it matches the image shape
+    mask = mask[:, :, np.newaxis]
+    mask = np.repeat(mask, 3, axis=2)
+
+    # Convert color to numpy array
+    color = np.array(color)
+
+    # Blend: attenuate each pixel by alpha * mask, then add the scaled color
+    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
+    return masked_image.astype(np.uint8)
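+# Worked example: with alpha=0.5, a mask value of 1.0 and color=(1, 0, 0), a
+# pixel (100, 100, 100) becomes (50 + 127, 50, 50) = (177, 50, 50), i.e. a
+# half-strength red overlay.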
+
+def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
+    """Visualize attention maps for the given image and model."""
+    # Normalize to RGB so the model transform and overlay math both see 3 channels
+    image = image.convert('RGB')
+    model, extractor = load_model(model_name)
+    attention_maps = process_image(image, model, extractor)
+
+    # Convert PIL Image to numpy array
+    image_np = np.array(image)
+
+    # Create visualizations
+    visualizations = []
+    for layer_name, attn_map in attention_maps.items():
+        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
+
+        # Take the CLS token's attention to the patch tokens, averaged over heads
+        attn_map = attn_map[0, :, 0, 1:].mean(0)  # Shape: (seq_len-1,)
+
+        # Reshape the attention map to a 2D patch grid (assumes a square grid
+        # and a single class token)
+        num_patches = int(np.sqrt(attn_map.shape[0]))
+        attn_map = attn_map.reshape(num_patches, num_patches)
+
+        # Interpolate to match image size (attn_map is already a torch tensor)
+        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
+        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
+        attn_map = attn_map.squeeze().cpu().numpy()
+
+        # Normalize attention map to [0, 1]
+        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
+
+        # Create visualization
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
+
+        # Original image
+        ax1.imshow(image_np)
+        ax1.set_title("Original Image")
+        ax1.axis('off')
+
+        # Attention map overlay
+        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
+        ax2.imshow(masked_image)
+        ax2.set_title(f'Attention Map for {layer_name}')
+        ax2.axis('off')
+
+        plt.tight_layout()
+
+        # Convert plot to image
+        fig.canvas.draw()
+        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
+        visualizations.append(vis_image)
+        plt.close(fig)
+
+    return visualizations
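+# The returned gallery holds one figure per transformer block, in depth order,
+# since the hooks fire in the order the blocks execute during the forward pass.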
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=visualize_attention,
+    inputs=[
+        gr.Image(type="pil", label="Input Image"),
+        gr.Dropdown(choices=get_attention_models(), label="Select Model")
+    ],
+    outputs=gr.Gallery(label="Attention Maps"),
+    title="Attention Map Visualizer for timm Models",
+    description="Upload an image and select a timm model to visualize its attention maps."
+)
+
+iface.launch(debug=True)
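+# launch(debug=True) blocks and surfaces tracebacks in the console; run
+# locally, the app is served at http://127.0.0.1:7860 by default.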