Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
c0c08a5
1 Parent(s): 591e456

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +59 -54
README.md CHANGED
@@ -1,76 +1,81 @@
1
  ---
2
- license: apache-2.0
3
- library_name: mlx-image
4
- tags:
5
- - mlx
6
- - mlx-image
7
- - vision
8
- - image-classification
9
- datasets:
10
- - imagenet-1k
11
  ---
12
- # vit_base_patch8_224.dino
13
 
14
- A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
17
 
18
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
19
 
20
- <div align="center">
21
- <img width="100%" alt="DINO illustration" src="dino.gif">
22
- </div>
23
 
 
 
 
24
 
25
- ## How to use
26
- ```bash
27
- pip install mlx-image
28
- ```
29
 
30
- Here is how to use this model for image classification:
 
 
 
31
 
32
- ```python
33
- from mlxim.model import create_model
34
- from mlxim.io import read_rgb
35
- from mlxim.transform import ImageNetTransform
36
 
37
- transform = ImageNetTransform(train=False, img_size=224)
38
- x = transform(read_rgb("cat.png"))
39
- x = mx.expand_dims(x, 0)
 
40
 
41
- model = create_model("vit_base_patch8_224.dino")
42
- model.eval()
 
43
 
44
- logits, attn_masks = model(x, attn_masks=True)
45
- ```
46
 
47
- You can also use the embeds from layer before head:
48
- ```python
49
- from mlxim.model import create_model
50
- from mlxim.io import read_rgb
51
- from mlxim.transform import ImageNetTransform
52
 
53
- transform = ImageNetTransform(train=False, img_size=512)
54
- x = transform(read_rgb("cat.png"))
55
- x = mx.expand_dims(x, 0)
 
 
56
 
57
- # first option
58
- model = create_model("vit_base_patch8_224.dino", num_classes=0)
59
- model.eval()
60
 
61
- embeds = model(x)
 
 
62
 
63
- # second option
64
- model = create_model("vit_base_patch8_224.dino")
65
- model.eval()
66
 
67
- embeds, attn_masks = model.get_features(x)
68
- ```
 
69
 
70
- ## Attention maps
71
- You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/notebooks/dino_attention.ipynb).
72
 
73
- <div align="center">
74
- <img width="100%" alt="Attention Map" src="attention_maps.png">
75
- </div>
76
 
 
 
 
 
 
 
1
  ---
2
+ {}
 
 
 
 
 
 
 
 
3
  ---
 
4
 
5
+ ---
6
+ license: apache-2.0
7
+ tags:
8
+ - mlx
9
+ - mlx-image
10
+ - vision
11
+ - image-classification
12
+ datasets:
13
+ - imagenet-1k
14
+ library_name: mlx-image
15
+ ---
16
+ # vit_base_patch8_224.dino
17
 
18
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
19
 
20
+ The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
21
 
22
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
 
 
23
 
24
+ <div align="center">
25
+ <img width="100%" alt="DINO illustration" src="dino.gif">
26
+ </div>
27
 
 
 
 
 
28
 
29
+ ## How to use
30
+ ```bash
31
+ pip install mlx-image
32
+ ```
33
 
34
+ Here is how to use this model for image classification:
 
 
 
35
 
36
+ ```python
37
+ from mlxim.model import create_model
38
+ from mlxim.io import read_rgb
39
+ from mlxim.transform import ImageNetTransform
40
 
41
+ transform = ImageNetTransform(train=False, img_size=224)
42
+ x = transform(read_rgb("cat.png"))
43
+ x = mx.expand_dims(x, 0)
44
 
45
+ model = create_model("vit_base_patch8_224.dino")
46
+ model.eval()
47
 
48
+ logits, attn_masks = model(x, attn_masks=True)
49
+ ```
 
 
 
50
 
51
+ You can also use the embeds from layer before head:
52
+ ```python
53
+ from mlxim.model import create_model
54
+ from mlxim.io import read_rgb
55
+ from mlxim.transform import ImageNetTransform
56
 
57
+ transform = ImageNetTransform(train=False, img_size=512)
58
+ x = transform(read_rgb("cat.png"))
59
+ x = mx.expand_dims(x, 0)
60
 
61
+ # first option
62
+ model = create_model("vit_base_patch8_224.dino", num_classes=0)
63
+ model.eval()
64
 
65
+ embeds = model(x)
 
 
66
 
67
+ # second option
68
+ model = create_model("vit_base_patch8_224.dino")
69
+ model.eval()
70
 
71
+ embeds, attn_masks = model.get_features(x)
72
+ ```
73
 
74
+ ## Attention maps
75
+ You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
 
76
 
77
+ <div align="center">
78
+ <img width="100%" alt="Attention Map" src="attention_maps.png">
79
+ </div>
80
+
81
+