File size: 5,373 Bytes
9571b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import torch
import timm
import torch.nn.functional as F
from timm.models import create_model
from timm.data import create_transform
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
from collections import OrderedDict

class AttentionExtractor:
    """Record tensors produced by a model's ``*.attn_drop`` submodules via forward hooks.

    After each forward pass of ``model``, ``attention_maps`` maps every hooked
    submodule's qualified name (as yielded by ``named_modules()``) to the tensor that
    submodule produced. The same dict is reused across passes, so entries for modules
    that run again are overwritten.

    NOTE(review): treating the ``attn_drop`` output as the attention probabilities
    assumes the model applies that dropout to the softmaxed attention matrix (timm's
    non-fused attention path, which ``load_model`` selects) — confirm for any model
    family other than ViT.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.attention_maps = OrderedDict()
        # Handles returned by register_forward_hook, kept so remove_hooks() can undo them.
        self._hook_handles = []
        self._register_hooks()

    def _make_hook(self, name: str):
        """Build a forward hook that stores a module's output under ``name``."""
        def hook_fn(module, inputs, output):
            # Tuple outputs are assumed to be (output, attention_probs); keep the latter.
            attn = output[1] if isinstance(output, tuple) else output
            if isinstance(attn, torch.Tensor):
                # Detach so a stored map never keeps an autograd graph alive.
                attn = attn.detach()
            self.attention_maps[name] = attn
        return hook_fn

    def _register_hooks(self):
        """Attach a recording hook to every submodule whose name ends in '.attn_drop'."""
        for name, module in self.model.named_modules():
            if name.lower().endswith('.attn_drop'):
                print('hooking', name)
                # Bind the layer name in a closure rather than mutating the third-party module.
                handle = module.register_forward_hook(self._make_hook(name))
                self._hook_handles.append(handle)

    def remove_hooks(self) -> None:
        """Unregister every hook this extractor added; already-captured maps are kept."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def get_attention_maps(self) -> OrderedDict:
        """Return the live name -> tensor dict filled by the hooks (not a copy)."""
        return self.attention_maps

def get_attention_models() -> List[str]:
    """Return the timm model names offered in the model dropdown.

    Every name from ``timm.list_models()`` that contains "vit" (case-insensitive)
    is kept — a deliberately simple stand-in for "has ViT-style attention blocks".
    """
    # NOTE(review): list_models() is called without arguments, so this may include
    # architectures without pretrained weights or without a CLS-token attention
    # layout; load_model / visualize_attention may fail on those — confirm and
    # consider a stricter filter.
    return list(filter(lambda name: 'vit' in name.lower(), timm.list_models()))

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
    """Build pretrained timm model ``model_name`` in eval mode and hook its attention.

    Returns:
        ``(model, extractor)`` where ``extractor`` has its hooks registered on ``model``.
    """
    # set_fused_attn takes no model argument, so this is a timm-wide switch; it must run
    # before create_model. Presumably it forces the explicit attention path so the
    # attn_drop modules hooked by AttentionExtractor actually execute — confirm against
    # the installed timm version.
    timm.layers.set_fused_attn(False)
    net = create_model(model_name, pretrained=True).eval()
    return net, AttentionExtractor(net)

def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
    """Run one forward pass on ``image`` and return the attention maps captured by ``extractor``.

    Args:
        image: Input image; converted to RGB before preprocessing.
        model: timm model whose pretrained data config defines the eval transform.
        extractor: AttentionExtractor whose hooks are registered on ``model``.

    Returns:
        The extractor's own name -> tensor mapping (not a copy), so the next forward
        pass through ``model`` overwrites it.
    """
    # timm's resolver fills in any preprocessing keys the pretrained config omits and
    # includes crop_mode, which hand-picking keys from pretrained_cfg silently dropped.
    data_config = timm.data.resolve_model_data_config(model)
    transform = create_transform(**data_config, is_training=False)

    # Normalisation uses 3-channel mean/std, so drop alpha / expand grayscale first.
    tensor = transform(image.convert('RGB')).unsqueeze(0)

    # Inference only; the forward hooks fill the extractor's dict as a side effect.
    with torch.no_grad():
        model(tensor)

    return extractor.get_attention_maps()

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Tint ``image`` towards ``color`` in proportion to ``mask``.

    Per pixel: ``image * (1 - alpha * m) + alpha * m * color * 255``.

    Args:
        image: (H, W, 3) image array.
        mask: (H, W) weights; values in [0, 1] keep the result within 0..255.
        color: RGB tint with components in [0, 1] (scaled by 255 here).
        alpha: Maximum blend strength where the mask is 1.

    Returns:
        The blended image cast to uint8 (fractional parts are truncated).
    """
    # Lift the (H, W) mask to (H, W, 3): one identical copy per colour channel.
    column = np.expand_dims(mask, 2)
    weight = np.concatenate([column, column, column], axis=2)

    tint = np.array(color)[None, None, :]

    blended = image * (1 - alpha * weight) + alpha * weight * tint * 255
    return blended.astype(np.uint8)

def _attention_heatmap(attn_map, num_prefix_tokens, out_size, grid_size=None):
    """Collapse one layer's attention tensor into a [0, 1] heatmap of size ``out_size``.

    Args:
        attn_map: Captured attention tensor, assumed (batch, heads, tokens, tokens)
            with the prefix tokens first — TODO confirm for non-ViT models.
        num_prefix_tokens: Leading non-patch tokens (class / distillation / register).
        out_size: (height, width) of the returned heatmap.
        grid_size: Optional (rows, cols) patch grid; used when it matches the patch
            count, otherwise a square grid is assumed.

    Returns:
        A float32 numpy array of shape ``out_size`` scaled to [0, 1], or None when the
        tensor is not 4-D or its patches cannot be laid out on a 2-D grid.
    """
    if attn_map.dim() != 4:
        return None

    if num_prefix_tokens > 0:
        # Attention from the first (class) token to every patch, averaged over heads.
        patch_attn = attn_map[0, :, 0, num_prefix_tokens:].mean(0)
    else:
        # No class token: average the attention each patch receives over heads and queries.
        patch_attn = attn_map[0].mean(dim=(0, 1))

    num_patches = patch_attn.shape[0]
    if num_patches == 0:
        return None
    if grid_size is not None and grid_size[0] * grid_size[1] == num_patches:
        rows, cols = grid_size
    else:
        side = int(round(np.sqrt(num_patches)))
        if side * side != num_patches:
            return None  # e.g. windowed attention or a non-square grid we cannot infer
        rows = cols = side

    grid = patch_attn.reshape(1, 1, rows, cols).float()
    heat = F.interpolate(grid, size=out_size, mode='bilinear', align_corners=False)[0, 0]
    heat = heat.cpu().numpy()

    # Min-max normalise; a constant map would otherwise divide by zero and yield NaNs.
    lo, hi = heat.min(), heat.max()
    if hi > lo:
        return (heat - lo) / (hi - lo)
    return np.zeros_like(heat)


def _figure_to_image(fig) -> Image.Image:
    """Rasterise a Matplotlib figure with an Agg-based canvas into an RGB PIL image."""
    fig.canvas.draw()
    # buffer_rgba() reflects the real pixel size (HiDPI-safe), unlike the removed tostring_rgb().
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return Image.fromarray(rgba).convert('RGB')


def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Render one overlay figure per captured attention layer of ``model_name`` on ``image``.

    Args:
        image: Image from the Gradio input; converted to RGB.
        model_name: timm model name chosen in the dropdown.

    Returns:
        One PIL image per layer whose attention could be mapped onto a patch grid,
        showing the original image beside a red attention overlay.

    Raises:
        gr.Error: if no image or no model was provided (Gradio shows the message).
    """
    if image is None or not model_name:
        raise gr.Error("Please upload an image and select a model.")
    # The overlay in apply_mask expects exactly 3 channels.
    image = image.convert('RGB')

    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    image_np = np.array(image)
    out_size = (image_np.shape[0], image_np.shape[1])
    # Defaults reproduce the single-CLS-token, square-grid layout when the model
    # does not expose these attributes.
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)
    grid_size = getattr(getattr(model, 'patch_embed', None), 'grid_size', None)

    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")

        # NOTE(review): the model saw a resized/cropped view (see process_image), so
        # stretching the heatmap over the full original image may misalign slightly.
        heatmap = _attention_heatmap(attn_map, num_prefix_tokens, out_size, grid_size)
        if heatmap is None:
            print(f"Skipping {layer_name}: cannot map attention of shape {tuple(attn_map.shape)} onto a patch grid")
            continue

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
        try:
            ax1.imshow(image_np)
            ax1.set_title("Original Image")
            ax1.axis('off')

            masked_image = apply_mask(image_np, heatmap, color=(1, 0, 0))  # Red mask
            ax2.imshow(masked_image)
            ax2.set_title(f'Attention Map for {layer_name}')
            ax2.axis('off')

            plt.tight_layout()
            visualizations.append(_figure_to_image(fig))
        finally:
            # Release the figure even if rendering fails so pyplot figures don't pile up.
            plt.close(fig)

    return visualizations

# Gradio UI: an image upload plus a dropdown of candidate timm models; the output
# gallery shows one overlay figure per hooked attention layer.
image_input = gr.Image(type="pil", label="Input Image")
model_dropdown = gr.Dropdown(choices=get_attention_models(), label="Select Model")
gallery_output = gr.Gallery(label="Attention Maps")

iface = gr.Interface(
    fn=visualize_attention,
    inputs=[image_input, model_dropdown],
    outputs=gallery_output,
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps.",
)

# NOTE: there is no __main__ guard, so importing this module starts the server.
iface.launch(debug=True)