rwightman HF staff commited on
Commit
cc361dd
1 Parent(s): 77f3515

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -20,7 +20,7 @@ def get_attention_models() -> List[str]:
20
  attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
21
  return attention_models
22
 
23
- def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
24
  """Load a model from timm and prepare it for attention extraction."""
25
  timm.layers.set_fused_attn(False)
26
  model = create_model(model_name, pretrained=True)
@@ -28,7 +28,11 @@ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
28
  extractor = AttentionExtract(model, method='fx') # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
29
  return model, extractor
30
 
31
- def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
 
 
 
 
32
  """Process the input image and get the attention maps."""
33
  # Get the correct transform for the model
34
  config = model.pretrained_cfg
 
20
  attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
21
  return attention_models
22
 
23
+ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
24
  """Load a model from timm and prepare it for attention extraction."""
25
  timm.layers.set_fused_attn(False)
26
  model = create_model(model_name, pretrained=True)
 
28
  extractor = AttentionExtract(model, method='fx') # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
29
  return model, extractor
30
 
31
+ def process_image(
32
+ image: Image.Image,
33
+ model: torch.nn.Module,
34
+ extractor: AttentionExtract
35
+ ) -> Dict[str, torch.Tensor]:
36
  """Process the input image and get the attention maps."""
37
  # Get the correct transform for the model
38
  config = model.pretrained_cfg