Spaces:
Running
on
Zero
Running
on
Zero
tori29umai
commited on
Commit
•
c9dfb9e
1
Parent(s):
6477f29
Update
Browse files- app.py +166 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import csv
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import onnxruntime as ort
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
import spaces
|
12 |
+
|
13 |
+
# 画像のサイズ設定
|
14 |
+
IMAGE_SIZE = 448
|
15 |
+
|
16 |
+
def preprocess_image(image):
|
17 |
+
image = np.array(image)
|
18 |
+
image = image[:, :, ::-1] # BGRからRGBへ変換
|
19 |
+
|
20 |
+
# 画像を正方形にするためのパディングを追加
|
21 |
+
size = max(image.shape[0:2])
|
22 |
+
pad_x = size - image.shape[1]
|
23 |
+
pad_y = size - image.shape[0]
|
24 |
+
pad_l = pad_x // 2
|
25 |
+
pad_t = pad_y // 2
|
26 |
+
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
27 |
+
|
28 |
+
# サイズに合わせた補間方法を選択
|
29 |
+
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
30 |
+
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
31 |
+
image = image.astype(np.float32)
|
32 |
+
return image
|
33 |
+
|
34 |
+
@spaces.GPU
|
35 |
+
def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
|
36 |
+
thresh = 0.35
|
37 |
+
try:
|
38 |
+
image = Image.open(image_path)
|
39 |
+
image = image.convert("RGB") if image.mode != "RGB" else image
|
40 |
+
image = preprocess_image(image)
|
41 |
+
except Exception as e:
|
42 |
+
print(f"画像を読み込めません: {image_path}, エラー: {e}")
|
43 |
+
return
|
44 |
+
|
45 |
+
img = np.array([image])
|
46 |
+
prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
|
47 |
+
|
48 |
+
# NSFW/SFW判定
|
49 |
+
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
50 |
+
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
51 |
+
max_sfw_score = tag_confidences.get("general", 0)
|
52 |
+
NSFW_flag = None
|
53 |
+
|
54 |
+
if max_nsfw_score > max_sfw_score:
|
55 |
+
NSFW_flag = "NSFWの可能性が高いです"
|
56 |
+
else:
|
57 |
+
NSFW_flag = "SFWの可能性が高いです"
|
58 |
+
|
59 |
+
# 版権キャラクターの可能性を評価
|
60 |
+
character_tags_with_probs = []
|
61 |
+
for i, p in enumerate(prob[4:]):
|
62 |
+
if p >= thresh and i >= len(general_tags):
|
63 |
+
tag_index = i - len(general_tags)
|
64 |
+
if tag_index < len(character_tags):
|
65 |
+
tag_name = character_tags[tag_index]
|
66 |
+
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
67 |
+
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
68 |
+
|
69 |
+
IP_flag = None
|
70 |
+
if character_tags_with_probs:
|
71 |
+
IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
|
72 |
+
else:
|
73 |
+
IP_flag = "版権キャラクターの可能性が低いと思われます"
|
74 |
+
|
75 |
+
# タグを生成
|
76 |
+
tag_freq = {}
|
77 |
+
undesired_tags = []
|
78 |
+
combined_tags = []
|
79 |
+
general_tag_text = ""
|
80 |
+
character_tag_text = ""
|
81 |
+
remove_underscore = True
|
82 |
+
caption_separator = ", "
|
83 |
+
general_threshold = 0.35
|
84 |
+
character_threshold = 0.35
|
85 |
+
|
86 |
+
for i, p in enumerate(prob[4:]):
|
87 |
+
if i < len(general_tags) and p >= general_threshold:
|
88 |
+
tag_name = general_tags[i]
|
89 |
+
if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
|
90 |
+
tag_name = tag_name.replace("_", " ")
|
91 |
+
|
92 |
+
if tag_name not in undesired_tags:
|
93 |
+
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
94 |
+
general_tag_text += caption_separator + tag_name
|
95 |
+
combined_tags.append(tag_name)
|
96 |
+
elif i >= len(general_tags) and p >= character_threshold:
|
97 |
+
tag_name = character_tags[i - len(general_tags)]
|
98 |
+
if remove_underscore and len(tag_name) > 3:
|
99 |
+
tag_name = tag_name.replace("_", " ")
|
100 |
+
|
101 |
+
if tag_name not in undesired_tags:
|
102 |
+
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
103 |
+
character_tag_text += caption_separator + tag_name
|
104 |
+
combined_tags.append(tag_name)
|
105 |
+
|
106 |
+
# 先頭のカンマを取る
|
107 |
+
if len(general_tag_text) > 0:
|
108 |
+
general_tag_text = general_tag_text[len(caption_separator) :]
|
109 |
+
if len(character_tag_text) > 0:
|
110 |
+
character_tag_text = character_tag_text[len(caption_separator) :]
|
111 |
+
tag_text = caption_separator.join(combined_tags)
|
112 |
+
|
113 |
+
return NSFW_flag, IP_flag, tag_text
|
114 |
+
|
115 |
+
|
116 |
+
class webui:
|
117 |
+
def __init__(self):
|
118 |
+
self.demo = gr.Blocks()
|
119 |
+
|
120 |
+
@spaces.GPU
|
121 |
+
def main(self, image_path, model_id):
|
122 |
+
print("Hugging Faceからモデルをダウンロード中")
|
123 |
+
onnx_path = hf_hub_download(model_id, "model.onnx")
|
124 |
+
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
125 |
+
|
126 |
+
print("ONNXモデルを実行中")
|
127 |
+
print(f"ONNXモデルのパス: {onnx_path}")
|
128 |
+
|
129 |
+
ort_sess = ort.InferenceSession(onnx_path)
|
130 |
+
|
131 |
+
with open(csv_path, "r", encoding="utf-8") as f:
|
132 |
+
reader = csv.reader(f)
|
133 |
+
header = next(reader)
|
134 |
+
rows = list(reader)
|
135 |
+
assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマット���期待と異なります: {header}"
|
136 |
+
|
137 |
+
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
138 |
+
character_tags = [row[1] for row in rows if row[2] == "4"]
|
139 |
+
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
140 |
+
|
141 |
+
NSFW_flag, IP_flag, tag_text = process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags)
|
142 |
+
return NSFW_flag, IP_flag, tag_text
|
143 |
+
|
144 |
+
|
145 |
+
def launch(self):
|
146 |
+
with self.demo:
|
147 |
+
with gr.Row():
|
148 |
+
with gr.Column():
|
149 |
+
input_image = gr.Image(type='filepath', label="Analysis Image")
|
150 |
+
model_id = gr.Textbox(label="NSFW Flag", value="SmilingWolf/wd-vit-tagger-v3")
|
151 |
+
output_0 = gr.Textbox(label="NSFW Flag")
|
152 |
+
output_1 = gr.Textbox(label="IP Flag")
|
153 |
+
output_2 = gr.Textbox(label="Tags")
|
154 |
+
submit = gr.Button(value="Start Analysis")
|
155 |
+
|
156 |
+
submit.click(
|
157 |
+
self.main,
|
158 |
+
inputs=[input_image, model_id],
|
159 |
+
outputs=[output_0, output_1, output_2]
|
160 |
+
)
|
161 |
+
|
162 |
+
self.demo.launch()
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
ui = webui()
|
166 |
+
ui.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python
|
2 |
+
numpy
|
3 |
+
Pillow
|
4 |
+
onnxruntime
|
5 |
+
onnxruntime-gpu
|