rwightman HF staff committed on
Commit
7163838
1 Parent(s): 8af18ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -24,6 +24,7 @@ class AttentionExtractor:
24
  self.attention_maps[module.full_name] = output
25
 
26
  for name, module in self.model.named_modules():
 
27
  if name.lower().endswith('.attn_drop'):
28
  module.full_name = name
29
  print('hooking', name)
@@ -34,8 +35,9 @@ class AttentionExtractor:
34
 
35
  def get_attention_models() -> List[str]:
36
  """Get a list of timm models that have attention blocks."""
37
- all_models = timm.list_models()
38
- attention_models = [model for model in all_models if 'vit' in model.lower()] # Focusing on ViT models for simplicity
 
39
  return attention_models
40
 
41
  def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
@@ -88,7 +90,8 @@ def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image
88
  """Visualize attention maps for the given image and model."""
89
  model, extractor = load_model(model_name)
90
  attention_maps = process_image(image, model, extractor)
91
-
 
92
  # Convert PIL Image to numpy array
93
  image_np = np.array(image)
94
 
@@ -97,8 +100,8 @@ def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image
97
  for layer_name, attn_map in attention_maps.items():
98
  print(f"Attention map shape for {layer_name}: {attn_map.shape}")
99
 
100
- # Remove the CLS token attention and average over heads
101
- attn_map = attn_map[0, :, 0, 1:].mean(0) # Shape: (seq_len-1,)
102
 
103
  # Reshape the attention map to 2D
104
  num_patches = int(np.sqrt(attn_map.shape[0]))
 
24
  self.attention_maps[module.full_name] = output
25
 
26
  for name, module in self.model.named_modules():
27
+ # FIXME need to make more generic outside of vit
28
  if name.lower().endswith('.attn_drop'):
29
  module.full_name = name
30
  print('hooking', name)
 
35
 
36
  def get_attention_models() -> List[str]:
37
  """Get a list of timm models that have attention blocks."""
38
+ all_models = timm.list_pretrained()
39
+ # FIXME Focusing on ViT models for initial impl
40
+ attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
41
  return attention_models
42
 
43
  def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
 
90
  """Visualize attention maps for the given image and model."""
91
  model, extractor = load_model(model_name)
92
  attention_maps = process_image(image, model, extractor)
93
+ num_prefix_tokens = getattr(model, 'num_prefix_tokens', 0)
94
+
95
  # Convert PIL Image to numpy array
96
  image_np = np.array(image)
97
 
 
100
  for layer_name, attn_map in attention_maps.items():
101
  print(f"Attention map shape for {layer_name}: {attn_map.shape}")
102
 
103
+ # Remove the CLS token attention and average over heads
104
+ attn_map = attn_map[0, :, 0, num_prefix_tokens:].mean(0) # Shape: (seq_len-1,)
105
 
106
  # Reshape the attention map to 2D
107
  num_patches = int(np.sqrt(attn_map.shape[0]))