tangjicheng commited on
Commit
18da417
1 Parent(s): 7a21097

new file: app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import onnxruntime as rt
4
+ import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
+ from dataclasses import dataclass
7
+
8
+ tagger_model_path = hf_hub_download(
9
+ repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")
10
+
11
+ tagger_model = rt.InferenceSession(
12
+ tagger_model_path, providers=['CPUExecutionProvider'])
13
+ tagger_model_meta = tagger_model.get_modelmeta().custom_metadata_map
14
+ tagger_tags = eval(tagger_model_meta['tags'])
15
+
16
+
17
+ @dataclass
18
+ class Tag:
19
+ lable: str
20
+ prob: float
21
+
22
+
23
+ def tagger_predict(image, score_threshold):
24
+ image = np.array(image)
25
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
26
+ s = 512
27
+ h, w = image.shape[:-1]
28
+ h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
29
+ ph, pw = s - h, s - w
30
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
31
+ image = cv2.copyMakeBorder(
32
+ image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
33
+ image = image.astype(np.float32) / 255
34
+ image = image[np.newaxis, :]
35
+ probs = tagger_model.run(None, {"input_1": image})[0][0]
36
+ probs = probs.astype(np.float32)
37
+ res = []
38
+ for prob, label in zip(probs.tolist(), tagger_tags):
39
+ if prob < score_threshold:
40
+ continue
41
+ res.append(Tag(label, prob))
42
+ sorted_res = sorted(res, key=lambda Tag: Tag.prob, reverse=True)
43
+ output_string = ""
44
+ output_string_without_prob = ""
45
+ for iter in sorted_res:
46
+ output_string += iter.lable + f" : {iter.prob:.2f}\n"
47
+ output_string_without_prob += iter.lable + "\n"
48
+ output_string = output_string[:-1]
49
+ output_string_without_prob = output_string_without_prob[:-1]
50
+ return (output_string, output_string_without_prob)
51
+
52
+
53
+ def gradio_wrapper(image, score_threshold):
54
+ return tagger_predict(image, score_threshold)
55
+
56
+
57
+ inputs = gr.inputs.Image()
58
+ slider = gr.inputs.Slider(minimum=0, maximum=1, default=0.5)
59
+ outputs = gr.outputs.Textbox()
60
+ outputs_list = gr.outputs.Textbox()
61
+
62
+ iface = gr.Interface(fn=gradio_wrapper,
63
+ inputs=[inputs, slider],
64
+ outputs=[outputs, outputs_list])
65
+
66
+ iface.launch()