Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ def get_attention_models() -> List[str]:
|
|
20 |
attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
|
21 |
return attention_models
|
22 |
|
23 |
-
def load_model(model_name: str) -> Tuple[torch.nn.Module,
|
24 |
"""Load a model from timm and prepare it for attention extraction."""
|
25 |
timm.layers.set_fused_attn(False)
|
26 |
model = create_model(model_name, pretrained=True)
|
@@ -28,7 +28,11 @@ def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
|
|
28 |
extractor = AttentionExtract(model, method='fx') # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
|
29 |
return model, extractor
|
30 |
|
31 |
-
def process_image(
|
|
|
|
|
|
|
|
|
32 |
"""Process the input image and get the attention maps."""
|
33 |
# Get the correct transform for the model
|
34 |
config = model.pretrained_cfg
|
|
|
20 |
attention_models = [model for model in all_models if any([model.lower().startswith(p) for p in ('vit', 'deit', 'beit', 'eva')])]
|
21 |
return attention_models
|
22 |
|
23 |
+
def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtract]:
|
24 |
"""Load a model from timm and prepare it for attention extraction."""
|
25 |
timm.layers.set_fused_attn(False)
|
26 |
model = create_model(model_name, pretrained=True)
|
|
|
28 |
extractor = AttentionExtract(model, method='fx') # can use 'hooks', can also allow specifying matching names for attention nodes or modules...
|
29 |
return model, extractor
|
30 |
|
31 |
+
def process_image(
|
32 |
+
image: Image.Image,
|
33 |
+
model: torch.nn.Module,
|
34 |
+
extractor: AttentionExtract
|
35 |
+
) -> Dict[str, torch.Tensor]:
|
36 |
"""Process the input image and get the attention maps."""
|
37 |
# Get the correct transform for the model
|
38 |
config = model.pretrained_cfg
|