lexandstuff commited on
Commit
51fe5bd
·
verified ·
1 Parent(s): 129048a

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +66 -3
  2. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,66 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: mlx-image
4
+ tags:
5
+ - mlx
6
+ - mlx-image
7
+ - vision
8
+ - image-classification
9
+ datasets:
10
+ - imagenet-1k
11
+ ---
12
+
13
+ # efficientnet_b1
14
+
15
+ An EfficientNet B1 model architecture, pretrained on ImageNet-1K.
16
+
17
+ Disclaimer: this is a port of the Torchvision model weights to Apple MLX Framework.
18
+
19
+ See [mlx-convert-scripts](https://github.com/lextoumbourou/mlx-convert-scripts) repo for the conversion script used.
20
+
21
+ ## How to use
22
+
23
+ ```bash
24
+ pip install mlx-image
25
+ ```
26
+
27
+ Here is how to use this model for image classification:
28
+
29
+ ```python
30
+ from mlxim.model import create_model
31
+ from mlxim.io import read_rgb
32
+ from mlxim.transform import ImageNetTransform
33
+
34
+ transform = ImageNetTransform(train=False, img_size=240)
35
+ x = transform(read_rgb("cat.png"))
36
+ x = mx.expand_dims(x, 0)
37
+
38
+ model = create_model("efficientnet_b1")
39
+ model.eval()
40
+
41
+ logits = model(x)
42
+ ```
43
+
44
+ You can also use the embeds from layer before head:
45
+
46
+ ```python
47
+ from mlxim.model import create_model
48
+ from mlxim.io import read_rgb
49
+ from mlxim.transform import ImageNetTransform
50
+
51
+ transform = ImageNetTransform(train=False, img_size=240)
52
+ x = transform(read_rgb("cat.png"))
53
+ x = mx.expand_dims(x, 0)
54
+
55
+ # first option
56
+ model = create_model("efficientnet_b1", num_classes=0)
57
+ model.eval()
58
+
59
+ embeds = model(x)
60
+
61
+ # second option
62
+ model = create_model("efficientnet_b1")
63
+ model.eval()
64
+
65
+ embeds = model.get_features(x)
66
+ ```
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b15be1fa42b15e4a4bd69eebcef8cd4b2182d28a09e093d0a51e68fb69667cd
3
+ size 31471286