"""Gradio demo that visualizes attention maps of timm vision transformers."""
import math
from collections import OrderedDict
from typing import Dict, List, Tuple

import gradio as gr
import matplotlib
import numpy as np
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from timm.data import create_transform
from timm.models import create_model

# Figures are only rendered off-screen and read back through the Agg canvas
# buffer, so select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# FIXME Focusing on ViT models for initial impl
_ATTENTION_MODEL_PREFIXES = ('vit', 'deit', 'beit', 'eva')


class AttentionExtractor:
    """Record the outputs of a model's ``attn_drop`` submodules.

    A forward hook is attached to every submodule whose qualified name ends
    with ``.attn_drop``. After each forward pass ``attention_maps`` maps that
    qualified name to the module's most recent output.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.attention_maps = OrderedDict()
        self._register_hooks()

    def _register_hooks(self):
        """Attach a capturing forward hook to every ``*.attn_drop`` submodule."""

        def make_hook(full_name: str):
            # Bind the layer name in a closure so the hook does not need to
            # stash extra attributes on the hooked module.
            def hook_fn(module, inputs, output):
                if isinstance(output, tuple):
                    self.attention_maps[full_name] = output[1]  # attention_probs
                else:
                    self.attention_maps[full_name] = output

            return hook_fn

        for name, module in self.model.named_modules():
            # FIXME need to make more generic outside of vit
            if name.lower().endswith('.attn_drop'):
                print('hooking', name)
                module.register_forward_hook(make_hook(name))

    def get_attention_maps(self) -> OrderedDict:
        """Return the mapping of hooked layer name to its last captured output."""
        return self.attention_maps


def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks.

    Only pretrained models whose name starts with one of
    ``_ATTENTION_MODEL_PREFIXES`` (case-insensitive) are returned.
    """
    return [
        name for name in timm.list_pretrained()
        if name.lower().startswith(_ATTENTION_MODEL_PREFIXES)
    ]


def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
    """Load a model from timm and prepare it for attention extraction.

    Returns:
        The pretrained model in eval mode and an ``AttentionExtractor`` whose
        hooks are already registered on it.
    """
    # Presumably required so the attention weights are materialized and pass
    # through ``attn_drop`` (where the hooks listen) instead of staying inside
    # a fused kernel -- confirm against the installed timm version.
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    extractor = AttentionExtractor(model)
    return model, extractor


def process_image(image: Image.Image, model: torch.nn.Module,
                  extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps.

    The image is preprocessed with the eval transform described by
    ``model.pretrained_cfg``, run through the model without gradients, and the
    extractor's captured maps are returned.
    """
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )

    # Preprocess the image. Force three channels so RGBA / palette / grayscale
    # inputs match the per-channel mean/std used for normalization.
    tensor = transform(image.convert('RGB')).unsqueeze(0)

    # Forward pass; the hooks fill the extractor as a side effect.
    with torch.no_grad():
        _ = model(tensor)

    # Extract attention maps
    return extractor.get_attention_maps()


def apply_mask(image: np.ndarray, mask: np.ndarray,
               color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Blend ``color`` into ``image`` with per-pixel strength ``alpha * mask``.

    Args:
        image: Array of shape (H, W, 3) holding 0-255 pixel values.
        mask: Array of shape (H, W); the normalized map passed by
            ``visualize_attention`` lies in [0, 1].
        color: Overlay color as three components, each scaled by 255.
        alpha: Maximum opacity of the overlay.

    Returns:
        The blended image as a ``uint8`` array of shape (H, W, 3).
    """
    # A trailing axis lets the 2-D mask broadcast across the color channels.
    weight = alpha * np.asarray(mask)[:, :, np.newaxis]
    overlay = np.asarray(color)[np.newaxis, np.newaxis, :] * 255

    masked_image = image * (1 - weight) + weight * overlay
    # Clip so out-of-range values cannot wrap around in the uint8 cast.
    return np.clip(masked_image, 0, 255).astype(np.uint8)


def _token_grid(model: torch.nn.Module, num_tokens: int) -> Tuple[int, int]:
    """Return the (rows, cols) layout used to reshape ``num_tokens`` patch tokens."""
    if num_tokens <= 0:
        raise ValueError("Attention map has no patch tokens to visualize.")

    # Prefer the grid advertised by the patch embedding when it matches the
    # token count; this also covers non-square inputs.
    grid = getattr(getattr(model, 'patch_embed', None), 'grid_size', None)
    if grid is not None and len(grid) == 2 and grid[0] * grid[1] == num_tokens:
        return int(grid[0]), int(grid[1])

    # Otherwise fall back to assuming a square grid.
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError(
            f"Cannot arrange {num_tokens} patch tokens into a square grid.")
    return side, side


def _attention_heatmap(token_attn: torch.Tensor, grid: Tuple[int, int],
                       size: Tuple[int, int]) -> np.ndarray:
    """Resize a 1-D per-token attention vector to ``size`` and scale it to [0, 1]."""
    heat = token_attn.reshape(1, 1, grid[0], grid[1]).float()
    heat = F.interpolate(heat, size=size, mode='bilinear', align_corners=False)
    heat = heat.squeeze(0).squeeze(0).cpu().numpy()

    low = heat.min()
    span = heat.max() - low
    if span > 0:
        return (heat - low) / span
    # Constant (or NaN) map: avoid dividing by zero and show no overlay.
    return np.zeros_like(heat)


def _render_overlay(image_np: np.ndarray, heatmap: np.ndarray,
                    layer_name: str) -> Image.Image:
    """Draw the original image next to its attention overlay and return the plot."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Attention map overlay
        masked_image = apply_mask(image_np, heatmap, color=(1, 0, 0))  # Red mask
        ax2.imshow(masked_image)
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')

        fig.tight_layout()

        # Convert plot to image. ``tostring_rgb`` was removed from recent
        # Matplotlib releases, so read the RGBA buffer and drop alpha.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)


def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Visualize attention maps for the given image and model.

    Args:
        image: Input picture; it is converted to RGB before use.
        model_name: Name of the timm model to load.

    Returns:
        One side-by-side (original, overlay) image per hooked attention layer,
        in the order the layers were captured.

    Raises:
        gr.Error: If no image or no model was supplied.
        ValueError: If a layer's patch tokens cannot be arranged on a 2-D grid.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if not model_name:
        raise gr.Error("Please select a model.")

    image = image.convert('RGB')
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 0)

    # Convert PIL Image to numpy array
    image_np = np.array(image)
    target_size = (image_np.shape[0], image_np.shape[1])

    # Create visualizations
    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")

        # First batch item, query token 0, prefix-token columns dropped, then
        # averaged over heads. Assumes a (batch, heads, tokens, tokens) layout
        # with token 0 being the class token -- TODO confirm per model family.
        token_attn = attn_map[0, :, 0, num_prefix_tokens:].mean(0)

        grid = _token_grid(model, token_attn.shape[0])
        heatmap = _attention_heatmap(token_attn, grid, target_size)
        visualizations.append(_render_overlay(image_np, heatmap, layer_name))

    return visualizations


# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
    ],
    outputs=gr.Gallery(label="Attention Maps"),
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps.",
)

# Only start the server when run as a script, so the module can be imported
# (e.g. to reuse the helpers) without launching the app.
if __name__ == "__main__":
    iface.launch(debug=True)