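"""Gradio app for visualizing attention maps of timm Vision Transformer models.

Upload an image and select a model; the app hooks the attention-dropout layers,
runs a forward pass, and overlays each layer's CLS-token attention on the image.
"""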
import gradio as gr
import torch
import timm
import torch.nn.functional as F
from timm.models import create_model
from timm.data import create_transform
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
from collections import OrderedDict

class AttentionExtractor:
    """Collects attention maps from a model by hooking its attention-dropout modules."""

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.attention_maps = OrderedDict()
        self._register_hooks()

    def _register_hooks(self):
        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                self.attention_maps[module.full_name] = output[1]  # attention_probs
            else:
                self.attention_maps[module.full_name] = output

        # Hook the attention dropout in each attention block; its output is the
        # post-softmax attention matrix.
        for name, module in self.model.named_modules():
            if name.lower().endswith('.attn_drop'):
                module.full_name = name
                print('hooking', name)
                module.register_forward_hook(hook_fn)

    def get_attention_maps(self) -> OrderedDict:
        return self.attention_maps

def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks."""
    all_models = timm.list_models()
    # Focusing on ViT models for simplicity
    attention_models = [model for model in all_models if 'vit' in model.lower()]
    return attention_models

def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
    """Load a model from timm and prepare it for attention extraction."""
    # Disable fused (SDPA) attention so attention probabilities are materialized
    # and can be captured by the forward hooks.
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    extractor = AttentionExtractor(model)
    return model, extractor

def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps."""
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )

    # Preprocess the image
    tensor = transform(image).unsqueeze(0)

    # Forward pass
    with torch.no_grad():
        _ = model(tensor)

    # Extract attention maps
    attention_maps = extractor.get_attention_maps()

    return attention_maps

def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Blend a normalized attention mask (values in [0, 1]) over the image in the given RGB color."""
    # Ensure mask and image have the same shape
    mask = mask[:, :, np.newaxis]
    mask = np.repeat(mask, 3, axis=2)

    # Convert color to numpy array
    color = np.array(color)

    # Apply mask
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)

def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Visualize attention maps for the given image and model."""
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        # Take the CLS token's attention to the patch tokens and average over heads.
        # attn_map: (batch, num_heads, seq_len, seq_len) -> (seq_len - 1,)
        attn_map = attn_map[0, :, 0, 1:].mean(0)

        # Reshape the attention map to a 2D patch grid
        num_patches = int(np.sqrt(attn_map.shape[0]))
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize attention map to [0, 1]
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        plt.tight_layout()

        # Convert plot to image
        fig.canvas.draw()
        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        visualizations.append(vis_image)
        plt.close(fig)

    return visualizations

# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model")
    ],
    outputs=gr.Gallery(label="Attention Maps"),
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

iface.launch(debug=True)