rwightman (HF staff) committed
Commit 77f3515
1 Parent(s): c2b7e3c

Update app.py

Files changed (1)
  1. app.py +9 -34
app.py CHANGED
@@ -1,37 +1,17 @@
+from typing import List, Tuple, Dict
+from collections import OrderedDict
+
 import gradio as gr
 import torch
-import timm
 import torch.nn.functional as F
-from timm.models import create_model
+import timm
 from timm.data import create_transform
+from timm.models import create_model
+from timm.utils import AttentionExtract
 from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt
-from typing import List, Tuple, Dict
-from collections import OrderedDict
-
-class AttentionExtractor:
-    def __init__(self, model: torch.nn.Module):
-        self.model = model
-        self.attention_maps = OrderedDict()
-        self._register_hooks()
-
-    def _register_hooks(self):
-        def hook_fn(module, input, output):
-            if isinstance(output, tuple):
-                self.attention_maps[module.full_name] = output[1]  # attention_probs
-            else:
-                self.attention_maps[module.full_name] = output
-
-        for name, module in self.model.named_modules():
-            # FIXME need to make more generic outside of vit
-            if name.lower().endswith('.attn_drop'):
-                module.full_name = name
-                print('hooking', name)
-                module.register_forward_hook(hook_fn)
-
-    def get_attention_maps(self) -> OrderedDict:
-        return self.attention_maps
+
 
 def get_attention_models() -> List[str]:
     """Get a list of timm models that have attention blocks."""
@@ -45,7 +25,7 @@ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
     timm.layers.set_fused_attn(False)
     model = create_model(model_name, pretrained=True)
     model.eval()
-    extractor = AttentionExtractor(model)
+    extractor = AttentionExtract(model, method='fx')  # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
     return model, extractor
 
 def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
@@ -61,16 +41,11 @@ def process_image(image: Image.Image, model: torch.nn.Module, extractor: Attenti
         is_training=False
     )
 
-
    # Preprocess the image
     tensor = transform(image).unsqueeze(0)
-
-    # Forward pass
-    with torch.no_grad():
-        _ = model(tensor)
-
+
     # Extract attention maps
-    attention_maps = extractor.get_attention_maps()
+    attention_maps = extractor(tensor)
 
     return attention_maps
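The commit replaces the hand-rolled, hook-based AttentionExtractor with timm's bundled AttentionExtract utility, which captures attention maps either by torch.fx graph tracing (method='fx') or by forward hooks (method='hooks') and returns them from a plain forward call, so the separate no_grad forward pass disappears. Below is a minimal sketch of the new wiring, based on the calls visible in this diff; the model name is an arbitrary example, and the resolve_data_config step is an assumption about the app code outside these hunks.

import timm
import torch
from timm.data import create_transform, resolve_data_config
from timm.models import create_model
from timm.utils import AttentionExtract

# Fused SDPA never materializes attention probabilities, so turn it off.
timm.layers.set_fused_attn(False)

model = create_model('vit_base_patch16_224', pretrained=True)  # example model name
model.eval()
extractor = AttentionExtract(model, method='fx')  # 'hooks' is the alternative per the inline comment

# Assumed preprocessing, mirroring the create_transform(..., is_training=False)
# call visible in process_image above.
config = resolve_data_config({}, model=model)
transform = create_transform(**config, is_training=False)

tensor = torch.randn(1, 3, 224, 224)  # stand-in for transform(image).unsqueeze(0)
with torch.no_grad():
    attention_maps = extractor(tensor)  # dict: node/module name -> attention tensor

for name, attn in attention_maps.items():
    print(name, tuple(attn.shape))  # [batch, heads, tokens, tokens] for a ViT

The set_fused_attn(False) call already present in load_model matters here: PyTorch's fused scaled-dot-product attention never materializes the attention probability matrix, so with it enabled there would be nothing for either extraction method to capture.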
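The matplotlib plotting side of app.py sits outside these hunks, but the dict returned by extractor(tensor) maps node or module names to attention tensors. One plausible way to reduce a single map to an image-sized heatmap for display, assuming a plain ViT (class token at index 0, square patch grid); cls_attention_heatmap is a hypothetical helper, not code from this commit.

import math
import torch
import torch.nn.functional as F

def cls_attention_heatmap(attn: torch.Tensor, image_size: int = 224) -> torch.Tensor:
    """Reduce one map of shape [batch, heads, tokens, tokens] to [batch, 1, H, W].

    Hypothetical helper; assumes token 0 is the class token and the remaining
    tokens form a square patch grid.
    """
    attn = attn.mean(dim=1)          # average over heads -> [batch, tokens, tokens]
    cls_to_patches = attn[:, 0, 1:]  # attention from CLS to each patch token
    grid = int(math.sqrt(cls_to_patches.shape[-1]))
    heatmap = cls_to_patches.reshape(-1, 1, grid, grid)
    # Upsample to image resolution so it can be overlaid with matplotlib.
    return F.interpolate(heatmap, size=(image_size, image_size), mode='bilinear', align_corners=False)

# Usage with the dict returned above, e.g. the deepest captured node:
# name = list(attention_maps.keys())[-1]
# overlay = cls_attention_heatmap(attention_maps[name])[0, 0].numpy()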