import gradio as gr
import torch
import timm
import torch.nn.functional as F
from timm.models import create_model
from timm.data import create_transform
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
from collections import OrderedDict
class AttentionExtractor:
    """Captures attention maps by hooking the attention-dropout layers of a timm model."""

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.attention_maps = OrderedDict()
        self._register_hooks()

    def _register_hooks(self):
        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                self.attention_maps[module.full_name] = output[1]  # attention_probs
            else:
                self.attention_maps[module.full_name] = output

        # The .attn_drop modules sit directly after the softmax, so their output
        # is the post-softmax attention matrix of shape (batch, heads, seq, seq).
        for name, module in self.model.named_modules():
            if name.lower().endswith('.attn_drop'):
                module.full_name = name
                print('hooking', name)
                module.register_forward_hook(hook_fn)

    def get_attention_maps(self) -> OrderedDict:
        return self.attention_maps
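
# A minimal usage sketch (outside the Gradio app), assuming a standard timm ViT.
# set_fused_attn(False) must run before model creation so timm takes the
# non-fused attention path, where .attn_drop actually sees the attention matrix:
#
#   timm.layers.set_fused_attn(False)
#   m = create_model('vit_base_patch16_224', pretrained=True).eval()
#   extractor = AttentionExtractor(m)
#   _ = m(torch.randn(1, 3, 224, 224))
#   maps = extractor.get_attention_maps()  # layer name -> (1, heads, 197, 197)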
def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_models()
    attention_models = [model for model in all_models if 'vit' in model.lower()]  # Focusing on ViT models for simplicity
    return attention_models
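
# Note: timm.list_models also accepts wildcard filters, so
# timm.list_models('*vit*') matches 'vit' anywhere in the name, much like
# the comprehension above.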
def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
    """Load a model from timm and prepare it for attention extraction."""
    # Disable fused (SDPA) attention so the attention matrix is materialized
    # and passes through .attn_drop, where the hooks can capture it.
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    extractor = AttentionExtractor(model)
    return model, extractor
def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps."""
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )

    # Preprocess the image
    tensor = transform(image).unsqueeze(0)

    # Forward pass
    with torch.no_grad():
        _ = model(tensor)

    # Extract attention maps
    attention_maps = extractor.get_attention_maps()
    return attention_maps
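
# Equivalent shortcut (a sketch, not used above): timm can resolve the whole
# transform config from the model in one step:
#
#   from timm.data import resolve_data_config
#   cfg = resolve_data_config({}, model=model)
#   transform = create_transform(**cfg, is_training=False)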
def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Overlay a [0, 1] heatmap on an HxWx3 uint8 image as a colored tint."""
    # Broadcast the mask across the three color channels
    mask = mask[:, :, np.newaxis]
    mask = np.repeat(mask, 3, axis=2)

    # Convert color to numpy array
    color = np.array(color)

    # Blend: darken the original where the mask is strong, then add the tint
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)
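
# Shape contract (for reference): image is HxWx3 uint8, mask is HxW float in
# [0, 1]; e.g. apply_mask(image_np, attn_map, color=(1, 0, 0)) tints the most
# attended regions red while leaving low-attention areas close to the original.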
def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Visualize attention maps for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    # Convert PIL Image to an RGB numpy array (Gradio may hand us RGBA)
    image_np = np.array(image.convert('RGB'))

    # Create visualizations
    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")

        # Take the CLS token's attention to the patch tokens, averaged over heads
        attn_map = attn_map[0, :, 0, 1:].mean(0)  # Shape: (seq_len - 1,)

        # Reshape the attention map to a 2D patch grid
        num_patches = int(np.sqrt(attn_map.shape[0]))
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match the image size (attn_map is already a tensor,
        # so no torch.tensor() re-wrapping is needed)
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize the attention map to [0, 1]
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert plot to image (note: tostring_rgb is deprecated in newer
        # matplotlib; np.asarray(fig.canvas.buffer_rgba()) is the
        # forward-compatible path)
        fig.canvas.draw()
        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        visualizations.append(vis_image)
        plt.close(fig)

    return visualizations
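
# Quick local check (a sketch bypassing the Gradio UI; assumes a local image file):
#
#   img = Image.open('example.jpg').convert('RGB')
#   panels = visualize_attention(img, 'vit_base_patch16_224')
#   panels[0].save('block0_attn.png')  # first hooked block's attention overlay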
# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model")
    ],
    outputs=gr.Gallery(label="Attention Maps"),
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

iface.launch(debug=True)