import spaces
from typing import List, Tuple, Dict
from collections import OrderedDict

import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_pretrained()
    # FIXME focusing on ViT-style models (vit, deit, beit, eva) for the initial implementation
    attention_models = [
        model for model in all_models
        if model.lower().startswith(('vit', 'deit', 'beit', 'eva'))
    ]
    return attention_models

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
    """Load a model from timm and prepare it for attention extraction."""
    timm.layers.set_fused_attn(False)  # Disable fused SDPA so attention matrices are materialized and can be extracted
    model = create_model(model_name, pretrained=True)
    model = model.cuda()  # Move the model to CUDA
    model.eval()
    extractor = AttentionExtract(model, method='fx')
    return model, extractor
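
# Illustrative usage sketch (not executed here): load a model and inspect which
# attention nodes the extractor exposes. The node names depend on the architecture;
# the ones shown below are hypothetical.
#
#   model, extractor = load_model('vit_base_patch16_224')
#   maps = extractor(torch.randn(1, 3, 224, 224).cuda())
#   print(list(maps.keys()))  # e.g. ['blocks.0.attn.softmax', ...]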

@spaces.GPU
def process_image(
        image: Image.Image,
        model: torch.nn.Module,
        extractor: AttentionExtract
) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps."""
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )
    
    # Preprocess the image and move to CUDA
    tensor = transform(image).unsqueeze(0).cuda()

    # Extract attention maps
    attention_maps = extractor(tensor)
    
    return attention_maps

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Overlay a [0, 1] mask on an RGB image as an alpha-blended colored highlight."""
    # Broadcast the (H, W) mask to (H, W, 3) to match the image
    mask = mask[:, :, np.newaxis]
    mask = np.repeat(mask, 3, axis=2)
    
    # Convert color to numpy array
    color = np.array(color)
    
    # Apply mask
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)

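# Attention rollout (Abnar & Zuidema, 2020): each layer's head-fused attention is
# averaged with the identity to account for residual connections, row-normalized,
# and the per-layer matrices are multiplied together:
#
#   A_hat_l = normalize(0.5 * (A_l + I));   rollout = A_hat_L @ ... @ A_hat_1
#
# The class-token row of the product, reshaped to the patch grid, gives the mask.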
def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
    """Compute the attention rollout over all layers, returning a 2D patch-grid mask in [0, 1]."""
    device = attentions[0].device
    result = torch.eye(attentions[0].size(-1)).to(device)
    with torch.no_grad():
        for attention in attentions:
            if head_fusion.startswith('mean'):  # 'mean_std' falls back to a plain mean for rollout
                attention_heads_fused = attention.mean(dim=0)
            elif head_fusion == "max":
                attention_heads_fused = attention.amax(dim=0)
            elif head_fusion == "min":
                attention_heads_fused = attention.amin(dim=0)
            else:
                raise ValueError("Attention head fusion type Not supported")

            # Discard the lowest attentions, but don't discard the prefix tokens
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices >= num_prefix_tokens]
            flat[indices] = 0

            I = torch.eye(attention_heads_fused.size(-1)).to(device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)  # Row-normalize so each token's outgoing attention sums to 1
            result = torch.matmul(a, result)
    
    # Attention from the first prefix token (usually the class token) to the image patches
    mask = result[0, num_prefix_tokens:]
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).cpu().numpy()
    mask = mask / np.max(mask)
    return mask

@spaces.GPU
def visualize_attention(
        image: Image.Image,
        model_name: str,
        head_fusion: str,
        discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
    """Visualize attention maps and rollout for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)
    
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified

    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    attentions_for_rollout = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        attn_map = attn_map[0].detach()  # Remove batch dimension and detach

        attentions_for_rollout.append(attn_map)

        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization

        if head_fusion == 'mean_std':
            attn_map = attn_map.mean(0) / attn_map.std(0)
        elif head_fusion == 'mean':
            attn_map = attn_map.mean(0)
        elif head_fusion == 'max':
            attn_map = attn_map.amax(0)
        elif head_fusion == 'min':
            attn_map = attn_map.amin(0)
        else:
            raise ValueError(f"Invalid head fusion method: {head_fusion}")

        # Use the first token's attention (usually the class token)
        attn_map = attn_map[0]

        # Reshape the attention map to 2D
        num_patches = int(attn_map.shape[0] ** 0.5)
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().detach().numpy()  # Move to CPU, detach, and convert to numpy

        # Normalize attention map
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert plot to image
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        vis_image = Image.fromarray(data)
        visualizations.append(vis_image)
        plt.close(fig)

    # Ensure tensors are on CPU and detached before converting to numpy
    attentions_for_rollout = [attn.cpu().detach() for attn in attentions_for_rollout]

    # Calculate rollout
    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)

    # Create rollout visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Original image
    ax1.imshow(image_np)
    ax1.set_title("Original Image")
    ax1.axis('off')

    # Rollout overlay
    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
    rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
    masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0))  # Red mask
    ax2.imshow(masked_image)
    ax2.set_title('Attention Rollout')
    ax2.axis('off')

    plt.tight_layout()

    # Convert plot to image
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    rollout_image = Image.fromarray(data)
    plt.close(fig)

    return visualizations, rollout_image

# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
        gr.Dropdown(
            choices=['mean_std', 'mean', 'max', 'min'],
            label="Head Fusion Method",
            value='mean'  # Default value
        ),
        gr.Slider(0, 1, 0.9, label="Discard Ratio", info="Ratio of lowest attentions to discard")
    ],
    outputs=[
        gr.Gallery(label="Attention Maps"),
        gr.Image(label="Attention Rollout")
    ],
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

# Launch the interface
iface.launch()
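
# Note: the @spaces.GPU decorators target Hugging Face ZeroGPU Spaces. A rough
# sketch for running elsewhere (assuming a local CUDA GPU and removing or stubbing
# the `spaces` decorators) could expose a public link instead:
#
#   iface.launch(share=True)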