Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
0ae62fc
1 Parent(s): e463c8d

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +76 -0
README.md ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: mlx-image
4
+ tags:
5
+ - mlx
6
+ - mlx-image
7
+ - vision
8
+ - image-classification
9
+ datasets:
10
+ - imagenet-1k
11
+ ---
12
+ # vit_base_patch8_224.dino
13
+
14
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
15
+
16
+ The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
17
+
18
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
19
+
20
+ <div align="center">
21
+ <img width="100%" alt="DINO illustration" src="dino.gif">
22
+ </div>
23
+
24
+
25
+ ## How to use
26
+ ```bash
27
+ pip install mlx-image
28
+ ```
29
+
30
+ Here is how to use this model for image classification:
31
+
32
+ ```python
33
+ from mlxim.model import create_model
34
+ from mlxim.io import read_rgb
35
+ from mlxim.transform import ImageNetTransform
36
+
37
+ transform = ImageNetTransform(train=False, img_size=224)
38
+ x = transform(read_rgb("cat.png"))
39
+ x = mx.expand_dims(x, 0)
40
+
41
+ model = create_model("vit_base_patch8_224.dino")
42
+ model.eval()
43
+
44
+ logits, attn_masks = model(x, attn_masks=True)
45
+ ```
46
+
47
+ You can also use the embeds from layer before head:
48
+ ```python
49
+ from mlxim.model import create_model
50
+ from mlxim.io import read_rgb
51
+ from mlxim.transform import ImageNetTransform
52
+
53
+ transform = ImageNetTransform(train=False, img_size=512)
54
+ x = transform(read_rgb("cat.png"))
55
+ x = mx.expand_dims(x, 0)
56
+
57
+ # first option
58
+ model = create_model("vit_base_patch8_224.dino", num_classes=0)
59
+ model.eval()
60
+
61
+ embeds = model(x)
62
+
63
+ # second option
64
+ model = create_model("vit_base_patch8_224.dino")
65
+ model.eval()
66
+
67
+ embeds, attn_masks = model.get_features(x)
68
+ ```
69
+
70
+ ## Attention maps
71
+ You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/notebooks/dino_attention.ipynb).
72
+
73
+ <div align="center">
74
+ <img width="100%" alt="Attention Map" src="attention_maps.png">
75
+ </div>
76
+