riccardomusmeci commited on
Commit
1e78889
1 Parent(s): 531604b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -48
README.md CHANGED
@@ -1,64 +1,62 @@
1
-
2
  ---
3
- license: apache-2.0
4
- tags:
5
- - mlx
6
- - mlx-image
7
- - vision
8
- - image-classification
9
- datasets:
10
- - imagenet-1k
11
- library_name: mlx-image
12
- ---
13
- # regnet_y_800mf
14
 
15
- A RegNetY-800MF image classification model. Pretrained in ImageNet by torchvision contributors (see ImageNet1K-V2 weight details https://github.com/pytorch/vision/issues/3995#new-recipe).
16
 
17
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
18
 
19
- ## How to use
20
- ```bash
21
- pip install mlx-image
22
- ```
23
 
24
- Here is how to use this model for image classification:
 
 
 
25
 
26
- ```python
27
- from mlxim.model import create_model
28
- from mlxim.io import read_rgb
29
- from mlxim.transform import ImageNetTransform
30
 
31
- transform = ImageNetTransform(train=False, img_size=224)
32
- x = transform(read_rgb("cat.png"))
33
- x = mx.expand_dims(x, 0)
 
34
 
35
- model = create_model("regnet_y_800mf")
36
- model.eval()
 
37
 
38
- logits = model(x)
39
- ```
40
 
41
- You can also use the embeds from layer before head:
42
- ```python
43
- from mlxim.model import create_model
44
- from mlxim.io import read_rgb
45
- from mlxim.transform import ImageNetTransform
46
 
47
- transform = ImageNetTransform(train=False, img_size=224)
48
- x = transform(read_rgb("cat.png"))
49
- x = mx.expand_dims(x, 0)
 
 
50
 
51
- # first option
52
- model = create_model("regnet_y_800mf", num_classes=0)
53
- model.eval()
54
 
55
- embeds = model(x)
 
 
56
 
57
- # second option
58
- model = create_model("regnet_y_800mf")
59
- model.eval()
60
 
61
- embeds = model.get_features(x)
62
- ```
 
63
 
64
-
 
 
 
1
  ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - ILSVRC/imagenet-1k
5
+ tags:
6
+ - mlx
7
+ - mlx-image
8
+ - vision
9
+ - image-classification
10
+ library_name: mlx-image
11
+ ---
 
12
 
13
+ # regnet_y_800mf
14
 
15
+ A RegNetY-800MF image classification model. Pretrained in ImageNet by torchvision contributors (see ImageNet1K-V2 weight details https://github.com/pytorch/vision/issues/3995#new-recipe).
16
 
17
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
 
 
 
18
 
19
+ ## How to use
20
+ ```bash
21
+ pip install mlx-image
22
+ ```
23
 
24
+ Here is how to use this model for image classification:
 
 
 
25
 
26
+ ```python
27
+ from mlxim.model import create_model
28
+ from mlxim.io import read_rgb
29
+ from mlxim.transform import ImageNetTransform
30
 
31
+ transform = ImageNetTransform(train=False, img_size=224)
32
+ x = transform(read_rgb("cat.png"))
33
+ x = mx.expand_dims(x, 0)
34
 
35
+ model = create_model("regnet_y_800mf")
36
+ model.eval()
37
 
38
+ logits = model(x)
39
+ ```
 
 
 
40
 
41
+ You can also use the embeds from layer before head:
42
+ ```python
43
+ from mlxim.model import create_model
44
+ from mlxim.io import read_rgb
45
+ from mlxim.transform import ImageNetTransform
46
 
47
+ transform = ImageNetTransform(train=False, img_size=224)
48
+ x = transform(read_rgb("cat.png"))
49
+ x = mx.expand_dims(x, 0)
50
 
51
+ # first option
52
+ model = create_model("regnet_y_800mf", num_classes=0)
53
+ model.eval()
54
 
55
+ embeds = model(x)
 
 
56
 
57
+ # second option
58
+ model = create_model("regnet_y_800mf")
59
+ model.eval()
60
 
61
+ embeds = model.get_features(x)
62
+ ```