from typing import List, Tuple, Dict
from collections import OrderedDict

import gradio as gr
import torch
import torch.nn.functional as F
import timm
from timm.data import create_transform
from timm.models import create_model
from timm.utils import AttentionExtract
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


# Model-name prefixes used to pick attention-based models out of timm.
# FIXME Focusing on ViT models for initial impl
_ATTENTION_MODEL_PREFIXES = ('vit', 'deit', 'beit', 'eva')

# Head fusion methods offered in the UI and accepted by visualize_attention.
_HEAD_FUSION_METHODS = ('mean_std', 'mean', 'max', 'min')


def get_attention_models() -> List[str]:
    """Get a list of timm models that have attention blocks.

    Returns:
        Names from ``timm.list_pretrained()`` whose lower-cased name starts
        with one of the supported prefixes.
    """
    # str.startswith accepts a tuple, so no per-prefix loop is needed.
    return [
        name for name in timm.list_pretrained()
        if name.lower().startswith(_ATTENTION_MODEL_PREFIXES)
    ]


def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
    """Load a model from timm and prepare it for attention extraction.

    Args:
        model_name: Name of a pretrained timm model.

    Returns:
        The model in eval mode and an ``AttentionExtract`` wrapper around it.
    """
    # Fused attention is switched off before the model is built so that the
    # extractor below can pick up the attention maps.
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
    extractor = AttentionExtract(model, method='fx')
    return model, extractor


def process_image(
        image: Image.Image,
        model: torch.nn.Module,
        extractor: AttentionExtract
) -> Dict[str, torch.Tensor]:
    """Process the input image and get the attention maps.

    Args:
        image: Input PIL image; converted to RGB before preprocessing.
        model: timm model whose ``pretrained_cfg`` defines the preprocessing.
        extractor: Attention extractor built for ``model``.

    Returns:
        The mapping returned by ``extractor`` for the preprocessed image.
    """
    # Get the correct transform for the model
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False
    )

    # Preprocess the image. The normalization stats are per-RGB-channel, so
    # RGBA / grayscale inputs are converted first.
    tensor = transform(image.convert('RGB')).unsqueeze(0)

    # Extract attention maps. This is inference only, so skip autograd
    # bookkeeping to save memory.
    with torch.no_grad():
        attention_maps = extractor(tensor)

    return attention_maps


def apply_mask(
        image: np.ndarray,
        mask: np.ndarray,
        color: Tuple[float, float, float],
        alpha: float = 0.5
) -> np.ndarray:
    """Blend a colored mask over an image.

    Args:
        image: Image array of shape (H, W, 3).
        mask: 2D array of shape (H, W); the blend assumes values in [0, 1].
        color: RGB color with components in [0, 1].
        alpha: Maximum opacity of the overlay.

    Returns:
        The blended image as a ``uint8`` array of shape (H, W, 3).
    """
    # Ensure mask and image have the same shape
    mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

    # Convert color to numpy array
    color = np.array(color)

    # Apply mask
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    # Clip so that out-of-range mask values cannot wrap around in the uint8 cast.
    return np.clip(masked_image, 0, 255).astype(np.uint8)


def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
    """Compute an attention rollout mask over the image patches.

    Args:
        attentions: Sequence of per-layer attention tensors with the head
            dimension first, i.e. (heads, tokens, tokens).
        discard_ratio: Fraction of the lowest fused attention values that is
            zeroed in every layer.
        head_fusion: 'max', 'min', or any name starting with 'mean'.
        num_prefix_tokens: Number of leading non-patch tokens.

    Returns:
        A square 2D numpy array, scaled so its maximum is 1 when the
        maximum is positive.

    Raises:
        ValueError: If ``attentions`` is empty or ``head_fusion`` is not
            supported.
    """
    # based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
    if not attentions:
        raise ValueError("No attention maps were provided for rollout")

    first = attentions[0]
    result = torch.eye(first.size(-1), dtype=first.dtype, device=first.device)
    with torch.no_grad():
        for attention in attentions:
            if head_fusion.startswith('mean'):
                # mean_std fusion doesn't appear to make sense with rollout
                attention_heads_fused = attention.mean(dim=0)
            elif head_fusion == "max":
                attention_heads_fused = attention.amax(dim=0)
            elif head_fusion == "min":
                attention_heads_fused = attention.amin(dim=0)
            else:
                raise ValueError("Attention head fusion type Not supported")

            # Discard the lowest attentions, but don't discard the prefix tokens.
            # `flat` is a view, so zeroing it edits attention_heads_fused in place.
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices >= num_prefix_tokens]
            flat[indices] = 0

            identity = torch.eye(
                attention_heads_fused.size(-1),
                dtype=attention_heads_fused.dtype,
                device=attention_heads_fused.device,
            )
            a = (attention_heads_fused + 1.0 * identity) / 2
            # keepdim=True divides every row by its own sum. Without it the
            # row sums broadcast along the last axis and scale the columns
            # instead, so rows are no longer normalized once values are discarded.
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the prefix tokens (usually class tokens)
    # and the image patches
    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
    mask = result[0, num_prefix_tokens:]
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).cpu().numpy()
    peak = np.max(mask)
    if peak > 0:  # an all-zero mask would otherwise turn into NaNs
        mask = mask / peak
    return mask


def _figure_to_image(fig) -> Image.Image:
    """Rasterize a Matplotlib figure into an RGB PIL image."""
    fig.canvas.draw()
    # canvas.tostring_rgb() is deprecated since Matplotlib 3.8 and removed in
    # later releases; buffer_rgba() is the supported accessor. convert() copies
    # the pixels, so the image stays valid after the figure is closed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return Image.fromarray(rgba).convert('RGB')


def _render_overlay(image_np: np.ndarray, mask: np.ndarray, title: str) -> Image.Image:
    """Render the original image next to the same image with a red mask overlay."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        # Original image
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')

        # Mask overlay
        ax2.imshow(apply_mask(image_np, mask, color=(1, 0, 0)))  # Red mask
        ax2.set_title(title)
        ax2.axis('off')

        fig.tight_layout()
        return _figure_to_image(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)


def visualize_attention(
        image: Image.Image,
        model_name: str,
        head_fusion: str,
        discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
    """Visualize attention maps and rollout for the given image and model.

    Args:
        image: Input PIL image.
        model_name: Name of the timm model to load.
        head_fusion: One of 'mean_std', 'mean', 'max', 'min'.
        discard_ratio: Fraction of lowest attentions discarded in the rollout.

    Returns:
        A list with one visualization per extracted attention map, and the
        rollout visualization.

    Raises:
        gr.Error: If no image or no model was provided.
        ValueError: If ``head_fusion`` is not a supported method.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not model_name:
        raise gr.Error("Please select a model.")
    # Validate before the expensive model load rather than inside the loop.
    if head_fusion not in _HEAD_FUSION_METHODS:
        raise ValueError(f"Invalid head fusion method: {head_fusion}")

    # The overlay code works on 3-channel arrays.
    image = image.convert('RGB')

    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    # FIXME handle wider range of models that may not have num_prefix_tokens attr
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified

    # Convert PIL Image to numpy array
    image_np = np.array(image)
    height, width = image_np.shape[:2]

    # Create visualizations
    visualizations = []
    attentions_for_rollout = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        attn_map = attn_map[0]  # Remove batch dimension
        attentions_for_rollout.append(attn_map)

        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization
        if head_fusion == 'mean_std':
            # Clamp the std so that identical heads do not divide by zero.
            eps = torch.finfo(attn_map.dtype).eps
            attn_map = attn_map.mean(0) / attn_map.std(0).clamp_min(eps)
        elif head_fusion == 'mean':
            attn_map = attn_map.mean(0)
        elif head_fusion == 'max':
            attn_map = attn_map.amax(0)
        else:  # 'min' (validated above)
            attn_map = attn_map.amin(0)

        # Use the first token's attention (usually the class token)
        # FIXME handle different prefix token scenarios
        attn_map = attn_map[0]

        # Reshape the attention map to 2D
        num_patches = int(attn_map.shape[0] ** 0.5)
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size. attn_map is already a tensor, so
        # detach it instead of re-wrapping it with torch.tensor().
        attn_map = attn_map.detach().float().unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(height, width), mode='bilinear', align_corners=False)
        attn_map = attn_map[0, 0].cpu().numpy()

        # Normalize attention map to [0, 1]; a constant map becomes all zeros
        # instead of NaN.
        low = attn_map.min()
        span = attn_map.max() - low
        attn_map = (attn_map - low) / span if span > 0 else np.zeros_like(attn_map)

        visualizations.append(
            _render_overlay(image_np, attn_map, f'Attention Map for {layer_name}')
        )

    # Calculate rollout
    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)

    # Resize the rollout mask to the image size
    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
    rollout_mask_resized = np.array(rollout_mask_pil.resize((width, height), Image.BICUBIC)) / 255.0

    # Create rollout visualization
    rollout_image = _render_overlay(image_np, rollout_mask_resized, 'Attention Rollout')

    return visualizations, rollout_image


# Create Gradio interface
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
        gr.Dropdown(
            choices=list(_HEAD_FUSION_METHODS),
            label="Head Fusion Method",
            value='mean'  # Default value
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.9,
            label="Discard Ratio",
            info="Ratio of lowest attentions to discard"
        )
    ],
    outputs=[
        gr.Gallery(label="Attention Maps"),
        gr.Image(label="Attention Rollout")
    ],
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps."
)

iface.launch()