from typing import List, Tuple, Dict

import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_pretrained()
    # FIXME Focusing on ViT models for initial impl
    attention_models = [
        model for model in all_models
        if model.lower().startswith(('vit', 'deit', 'beit', 'eva'))
    ]
    return attention_models

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
    """Load a model from timm and prepare it for attention extraction."""
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    extractor = AttentionExtract(model, method='fx')  # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
    return model, extractor

def process_image(
        image: Image.Image,
        model: torch.nn.Module,
        extractor: AttentionExtract
) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps."""
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )
    
    # Preprocess the image
    tensor = transform(image).unsqueeze(0)
       
    # Extract attention maps
    attention_maps = extractor(tensor)
    
    return attention_maps
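
# Usage sketch (comments only, not executed): wiring load_model and process_image
# together outside of the Gradio app. The model name and image path below are
# illustrative assumptions, not part of this app.
#
#   model, extractor = load_model('vit_base_patch16_224')
#   attn = process_image(Image.open('example.jpg').convert('RGB'), model, extractor)
#   for name, t in attn.items():
#       print(name, tuple(t.shape))  # roughly (1, num_heads, num_tokens, num_tokens)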

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Overlay a [0, 1] mask on an RGB uint8 image as an alpha-blended color highlight."""
    # Broadcast the 2D mask to 3 channels so it matches the image shape
    mask = mask[:, :, np.newaxis]
    mask = np.repeat(mask, 3, axis=2)

    # Convert color to numpy array
    color = np.array(color)

    # Alpha-blend: darken the original where the mask is strong and add the scaled color
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)
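
# Quick sanity check for the blend above (illustrative, not executed): with the
# default alpha=0.5, a fully-masked black pixel becomes half of the highlight color.
#
#   img = np.zeros((2, 2, 3), dtype=np.uint8)
#   m = np.ones((2, 2), dtype=np.float32)
#   apply_mask(img, m, color=(1, 0, 0))  # every pixel -> (127, 0, 0)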


def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
    """Compute attention rollout over all layers.

    Based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion.startswith('mean'):
                # mean_std fusion doesn't appear to make sense with rollout
                attention_heads_fused = attention.mean(dim=0)
            elif head_fusion == "max":
                attention_heads_fused = attention.amax(dim=0)
            elif head_fusion == "min":
                attention_heads_fused = attention.amin(dim=0)
            else:
                raise ValueError(f"Attention head fusion type not supported: {head_fusion}")

            # Discard the lowest attentions, but don't discard the prefix tokens
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices >= num_prefix_tokens]
            flat[indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            # Re-normalize so each row (query) sums to 1; keepdim is needed so the
            # division broadcasts per-row rather than per-column
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
    
    # Look at the total attention between the prefix tokens (usually class tokens)
    # and the image patches    
    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
    mask = result[0, num_prefix_tokens:]
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask
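
# Shape/usage sketch for rollout (illustrative, not executed): three layers of
# synthetic softmax-normalized attention over 1 prefix token + 4 patch tokens
# should produce a 2x2 mask. The random tensors below are assumptions for
# illustration only.
#
#   attns = [torch.softmax(torch.randn(3, 5, 5), dim=-1) for _ in range(3)]
#   m = rollout(attns, discard_ratio=0.5, head_fusion='mean', num_prefix_tokens=1)
#   assert m.shape == (2, 2)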


def visualize_attention(
        image: Image.Image,
        model_name: str,
        head_fusion: str,
        discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
    """Visualize attention maps and rollout for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)
    
    # FIXME handle wider range of models that may not have num_prefix_tokens attr
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified

    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    attentions_for_rollout = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        attn_map = attn_map[0]  # Remove batch dimension

        attentions_for_rollout.append(attn_map)

        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization

        if head_fusion == 'mean_std':
            attn_map = attn_map.mean(0) / attn_map.std(0)
        elif head_fusion == 'mean':
            attn_map = attn_map.mean(0)
        elif head_fusion == 'max':
            attn_map = attn_map.amax(0)
        elif head_fusion == 'min':
            attn_map = attn_map.amin(0)
        else:
            raise ValueError(f"Invalid head fusion method: {head_fusion}")

        # Use the first token's attention (usually the class token)
        # FIXME handle different prefix token scenarios
        attn_map = attn_map[0]

        # Reshape the attention map to 2D
        num_patches = int(attn_map.shape[0] ** 0.5)
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)  # already a tensor; add batch and channel dims
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize attention map
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert plot to image
        fig.canvas.draw()
        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        visualizations.append(vis_image)
        plt.close(fig)

    # Calculate rollout
    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)

    # Create rollout visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Original image
    ax1.imshow(image_np)
    ax1.set_title("Original Image")
    ax1.axis('off')

    # Rollout overlay
    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
    rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
    masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0))  # Red mask
    ax2.imshow(masked_image)
    ax2.set_title('Attention Rollout')
    ax2.axis('off')

    plt.tight_layout()

    # Convert plot to image
    fig.canvas.draw()
    rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    plt.close(fig)

    return visualizations, rollout_image


# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
        gr.Dropdown(
            choices=['mean_std', 'mean', 'max', 'min'],
            label="Head Fusion Method",
            value='mean'  # Default value
        ),
        gr.Slider(0, 1, 0.9, label="Discard Ratio", info="Ratio of lowest attentions to discard")
    ],
    outputs=[
        gr.Gallery(label="Attention Maps"),
        gr.Image(label="Attention Rollout")
    ],
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

iface.launch()