Commit
1f4a582
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: fancyfeast <[email protected]>
Co-authored-by: SmilingWolf <[email protected]>
Co-authored-by: fancyfeast <[email protected]>

Files changed (4) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +80 -0
  4. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: E621 Tagger (Z3D E621 Convnext)
3
+ emoji: 📈
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import huggingface_hub
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ import onnxruntime as rt
6
+ import numpy as np
7
+ import csv
8
+ import spaces
9
+
10
+ import onnxruntime as rt
11
+ e621_model_path = Path(huggingface_hub.snapshot_download('toynya/Z3D-E621-Convnext'))
12
+ e621_model_session = rt.InferenceSession(e621_model_path / 'model.onnx', providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
13
+ with open(e621_model_path / 'tags-selected.csv', mode='r', encoding='utf-8') as file:
14
+ csv_reader = csv.DictReader(file)
15
+ e621_model_tags = [row['name'].strip() for row in csv_reader]
16
+
17
+
18
+ def prepare_image_e621(image: Image.Image, target_size: int):
19
+ import numpy as np
20
+ # Pad image to square
21
+ image_shape = image.size
22
+ max_dim = max(image_shape)
23
+ pad_left = (max_dim - image_shape[0]) // 2
24
+ pad_top = (max_dim - image_shape[1]) // 2
25
+
26
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
27
+ padded_image.paste(image, (pad_left, pad_top))
28
+
29
+ # Resize
30
+ if max_dim != target_size:
31
+ padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
32
+
33
+ # Convert to numpy array
34
+ # Based on the ONNX graph, the model appears to expect inputs in the range of 0-255
35
+ image_array = np.asarray(padded_image, dtype=np.float32)
36
+
37
+ # Convert PIL-native RGB to BGR
38
+ image_array = image_array[:, :, ::-1]
39
+
40
+ return np.expand_dims(image_array, axis=0)
41
+
42
+
43
+ def predict_e621(image: Image.Image):
44
+ THRESHOLD = 0.3
45
+ image_array = prepare_image_e621(image, 448)
46
+
47
+ image_array = prepare_image_e621(image, 448)
48
+ input_name = 'input_1:0'
49
+ output_name = 'predictions_sigmoid'
50
+
51
+ result = e621_model_session.run([output_name], {input_name: image_array})
52
+ result = result[0][0]
53
+
54
+ scores = {e621_model_tags[i]: result[i] for i in range(len(result))}
55
+ predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
56
+ tag_string = ', '.join(predicted_tags).replace("_", " ")
57
+
58
+ return tag_string, scores
59
+
60
+
61
+ DESCRIPTION = """
62
+ E621 Tagger (Z3D-E621-Convnext)
63
+ - Image => E621 Pony Prompt
64
+ - Mod of [fancyfeast's demo](https://huggingface.co/spaces/fancyfeast/Z3D-E621-Convnext-space) for toynya's [Z3D-E621-Convnext](https://huggingface.co/toynya/Z3D-E621-Convnext)
65
+ """
66
+
67
+ gradio_app = gr.Interface(
68
+ predict_e621,
69
+ inputs=gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil'),
70
+ outputs=[
71
+ gr.Textbox(label="Tag String", show_copy_button=True),
72
+ gr.Label(label="Tag Predictions", num_top_classes=100),
73
+ ],
74
+ description=DESCRIPTION,
75
+ allow_flagging="never",
76
+ )
77
+
78
+
79
+ if __name__ == '__main__':
80
+ gradio_app.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy==1.26.3
2
+ onnxruntime==1.16.3