typorch committed on
Commit 3be2ac3
1 Parent(s): 463b38f

Upload 8 files
README.md ADDED
@@ -0,0 +1,49 @@
+ ---
+ license: apache-2.0
+ library_name: timm
+ ---
+ # WD ViT Tagger v3
+
+ Supports ratings, characters and general tags.
+
+ Trained using https://github.com/SmilingWolf/JAX-CV.
+ TPUs used for training kindly provided by the [TRC program](https://sites.research.google/trc/about/).
+
+ ## Dataset
+ Last image id: 7220105
+ Trained on Danbooru images with IDs modulo 0000-0899.
+ Validated on images with IDs modulo 0950-0999.
+ Images with fewer than 10 general tags were filtered out.
+ Tags with fewer than 600 images were filtered out.
+
+ ## Validation results
+ `v2.0: P=R: threshold = 0.2614, F1 = 0.4402`
+ `v1.0: P=R: threshold = 0.2547, F1 = 0.4278`
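The `P=R` thresholds above are the operating points where precision equals recall on the validation set; at inference time a tag is kept when its sigmoid score clears that threshold. A minimal sketch of that selection step (the tag names and scores below are made up for illustration; only the 0.2614 threshold comes from this card):

```python
import numpy as np

# v2.0 P=R threshold from the validation results above.
THRESHOLD = 0.2614

# Hypothetical per-tag sigmoid outputs for a single image.
tag_names = ["1girl", "solo", "smile", "outdoors"]
probs = np.array([0.98, 0.91, 0.31, 0.12])

# Keep every tag whose probability clears the threshold.
predicted = [name for name, p in zip(tag_names, probs) if p >= THRESHOLD]
print(predicted)  # → ['1girl', 'solo', 'smile']
```

Lowering the threshold trades precision for recall, so downstream users may prefer a different operating point.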
+
+ ## What's new
+ Model v2.0/Dataset v3:
+ Trained for a few more epochs.
+ Used tag frequency-based loss scaling to combat class imbalance.
+
+ Model v1.1/Dataset v3:
+ Amended the JAX model config file to add the image size.
+ No change to the trained weights.
+
+ Model v1.0/Dataset v3:
+ More training images, more and up-to-date tags (up to 2024-02-28).
+ Now `timm` compatible! Load it up and give it a spin using the canonical one-liner!
+ ONNX model is compatible with code developed for the v2 series of models.
+ The batch dimension of the ONNX model is no longer fixed to 1, so you can go crazy with batch inference.
+ Switched to Macro-F1 to measure model performance, since it gives me a better gauge of overall training progress.
+
+ # Runtime deps
+ The ONNX model requires `onnxruntime >= 1.17.0`.
+
+ # Inference code examples
+ For timm: https://github.com/neggles/wdv3-timm
+ For ONNX: https://huggingface.co/spaces/SmilingWolf/wd-tagger
+ For JAX: https://github.com/SmilingWolf/wdv3-jax
+
+ ## Final words
+ Subject to change and updates.
+ Downstream users are encouraged to use tagged releases rather than relying on the head of the repo.
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architecture": "vit_base_patch16_224",
+   "num_classes": 10861,
+   "num_features": 768,
+   "global_pool": "avg",
+   "model_args": {
+     "img_size": 448,
+     "class_token": false,
+     "global_pool": "avg",
+     "fc_norm": false,
+     "act_layer": "gelu_tanh"
+   },
+   "pretrained_cfg": {
+     "custom_load": false,
+     "input_size": [3, 448, 448],
+     "fixed_input_size": false,
+     "interpolation": "bicubic",
+     "crop_pct": 1.0,
+     "crop_mode": "center",
+     "mean": [0.5, 0.5, 0.5],
+     "std": [0.5, 0.5, 0.5],
+     "num_classes": 10861,
+     "pool_size": null,
+     "first_conv": null,
+     "classifier": null
+   }
+ }
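The `pretrained_cfg` implies a simple preprocessing pipeline: resize to 448×448 (bicubic, full-image crop), scale pixels to [0, 1], then normalize with per-channel mean 0.5 and std 0.5, which maps values to [-1, 1]. A hedged numpy sketch of the normalization step alone (the resize is omitted; the helper name is made up):

```python
import numpy as np

# Normalization implied by pretrained_cfg: mean = std = 0.5 per channel,
# so pixel values map from [0, 1] to [-1, 1].
MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)

def normalize(image_uint8: np.ndarray) -> np.ndarray:
    """image_uint8: HxWx3 array of uint8 pixel values."""
    x = image_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    return (x - MEAN) / STD                     # map to [-1, 1]

# A 2x2 dummy "image": black, white, mid-gray, dark-gray pixels.
img = np.array([[[0, 0, 0], [255, 255, 255]],
                [[128, 128, 128], [64, 64, 64]]], dtype=np.uint8)
out = normalize(img)
print(out[0, 0], out[0, 1])  # black → [-1. -1. -1.], white → [1. 1. 1.]
```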
model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8efa9946a6b967cb0566d5dc2ac6cf736344224273365e8082d1d49b0d987ae9
+ size 134
model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7a2a3606a12e94a6a66331df2a3558958332565647f0f20be2e2fd56a37d73b
+ size 134
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d38d2378878efe09139d020e80d264861e3cb1e8472448029abf3785763c264d
+ size 134
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc9b75e28cbb1f6c342fac49c46bbdff7ba801b3a519ed4e002ddedcc181a0e5
+ size 378459814
selected_tags.csv ADDED
The diff for this file is too large to render.
sw_jax_cv_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "image_size": 448,
+   "model_name": "vit_base",
+   "model_args": {
+     "patch_size": 16,
+     "num_classes": 10861,
+     "num_layers": 12,
+     "embed_dim": 768,
+     "mlp_dim": 3072,
+     "num_heads": 12,
+     "drop_path_rate": 0.1,
+     "layer_norm_eps": 1e-05
+   }
+ }
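As a sanity check, the hyperparameters in this config are enough to estimate the parameter count. A rough back-of-the-envelope sketch, assuming the usual pre-norm ViT layout with a fused QKV projection and no class token (small buffers may be unaccounted for):

```python
# Rough parameter-count estimate from sw_jax_cv_config.json.
# Assumes a standard pre-norm ViT layout; exact bookkeeping may differ slightly.
embed_dim, mlp_dim, num_layers, num_classes = 768, 3072, 12, 10861
patch_size, image_size, channels = 16, 448, 3

num_patches = (image_size // patch_size) ** 2           # 28 * 28 = 784 tokens
patch_embed = channels * patch_size**2 * embed_dim + embed_dim
pos_embed = num_patches * embed_dim                     # no class token

qkv = embed_dim * 3 * embed_dim + 3 * embed_dim         # fused QKV projection
attn_proj = embed_dim * embed_dim + embed_dim
mlp = (embed_dim * mlp_dim + mlp_dim) + (mlp_dim * embed_dim + embed_dim)
norms = 2 * 2 * embed_dim                               # two LayerNorms per block
per_block = qkv + attn_proj + mlp + norms

final_norm = 2 * embed_dim
head = embed_dim * num_classes + num_classes            # 10861-way classifier

total = patch_embed + pos_embed + num_layers * per_block + final_norm + head
print(f"~{total / 1e6:.1f}M params")  # → ~94.6M params
```

At 4 bytes per float32 weight that comes to roughly 378 MB, which is consistent with the `pytorch_model.bin` size listed above.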