Spaces:
Runtime error
Runtime error
tangjicheng
commited on
Commit
•
18da417
1
Parent(s):
7a21097
new file: app.py
Browse files
app.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime as rt
|
4 |
+
import gradio as gr
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
tagger_model_path = hf_hub_download(
|
9 |
+
repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")
|
10 |
+
|
11 |
+
tagger_model = rt.InferenceSession(
|
12 |
+
tagger_model_path, providers=['CPUExecutionProvider'])
|
13 |
+
tagger_model_meta = tagger_model.get_modelmeta().custom_metadata_map
|
14 |
+
tagger_tags = eval(tagger_model_meta['tags'])
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class Tag:
|
19 |
+
lable: str
|
20 |
+
prob: float
|
21 |
+
|
22 |
+
|
23 |
+
def tagger_predict(image, score_threshold):
|
24 |
+
image = np.array(image)
|
25 |
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
26 |
+
s = 512
|
27 |
+
h, w = image.shape[:-1]
|
28 |
+
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
|
29 |
+
ph, pw = s - h, s - w
|
30 |
+
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
|
31 |
+
image = cv2.copyMakeBorder(
|
32 |
+
image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
|
33 |
+
image = image.astype(np.float32) / 255
|
34 |
+
image = image[np.newaxis, :]
|
35 |
+
probs = tagger_model.run(None, {"input_1": image})[0][0]
|
36 |
+
probs = probs.astype(np.float32)
|
37 |
+
res = []
|
38 |
+
for prob, label in zip(probs.tolist(), tagger_tags):
|
39 |
+
if prob < score_threshold:
|
40 |
+
continue
|
41 |
+
res.append(Tag(label, prob))
|
42 |
+
sorted_res = sorted(res, key=lambda Tag: Tag.prob, reverse=True)
|
43 |
+
output_string = ""
|
44 |
+
output_string_without_prob = ""
|
45 |
+
for iter in sorted_res:
|
46 |
+
output_string += iter.lable + f" : {iter.prob:.2f}\n"
|
47 |
+
output_string_without_prob += iter.lable + "\n"
|
48 |
+
output_string = output_string[:-1]
|
49 |
+
output_string_without_prob = output_string_without_prob[:-1]
|
50 |
+
return (output_string, output_string_without_prob)
|
51 |
+
|
52 |
+
|
53 |
+
def gradio_wrapper(image, score_threshold):
|
54 |
+
return tagger_predict(image, score_threshold)
|
55 |
+
|
56 |
+
|
57 |
+
inputs = gr.inputs.Image()
|
58 |
+
slider = gr.inputs.Slider(minimum=0, maximum=1, default=0.5)
|
59 |
+
outputs = gr.outputs.Textbox()
|
60 |
+
outputs_list = gr.outputs.Textbox()
|
61 |
+
|
62 |
+
iface = gr.Interface(fn=gradio_wrapper,
|
63 |
+
inputs=[inputs, slider],
|
64 |
+
outputs=[outputs, outputs_list])
|
65 |
+
|
66 |
+
iface.launch()
|