Spaces:
Running
on
Zero
Running
on
Zero
Upload tagger.py
Browse files
tagger.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
from PIL import Image
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
-
import spaces
|
5 |
-
|
6 |
from transformers import (
|
7 |
AutoImageProcessor,
|
8 |
AutoModelForImageClassification,
|
9 |
)
|
10 |
|
|
|
11 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
12 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
13 |
|
@@ -49,6 +49,34 @@ DANBOORU_TO_E621_RATING_MAP = {
|
|
49 |
}
|
50 |
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
def to_list(s):
    """Split a comma-separated string into a list of stripped items.

    Returns an empty list when ``s`` is ``None`` or the empty string.
    Otherwise every comma-delimited piece is kept, with surrounding
    whitespace removed; pieces that are blank after stripping are retained
    as ``""`` so results for non-empty input are unchanged.
    """
    # The emptiness test concerns the whole input, so evaluate it once up
    # front instead of once per element inside the comprehension.
    if not s:
        return []
    return [x.strip() for x in s.split(",")]
|
54 |
|
@@ -110,7 +138,7 @@ def select_random_character(series: str, character: str):
|
|
110 |
def danbooru_to_e621(dtag, e621_dict):
|
111 |
def d_to_e(match, e621_dict):
|
112 |
dtag = match.group(0)
|
113 |
-
etag = e621_dict.get(dtag
|
114 |
if etag:
|
115 |
return etag
|
116 |
else:
|
@@ -134,7 +162,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
|
|
134 |
|
135 |
e621_dict = danbooru_to_e621_dict
|
136 |
for tag in tags:
|
137 |
-
tag = tag
|
138 |
tag = danbooru_to_e621(tag, e621_dict)
|
139 |
if tag in PEOPLE_TAGS:
|
140 |
people_tags.append(tag)
|
@@ -162,6 +190,7 @@ def translate_prompt(prompt: str = ""):
|
|
162 |
translated_prompt = translator.translate(prompt, src='auto', dest='en').text
|
163 |
return translated_prompt
|
164 |
except Exception as e:
|
|
|
165 |
return prompt
|
166 |
|
167 |
def is_japanese(s):
|
@@ -194,6 +223,7 @@ def translate_prompt_to_ja(prompt: str = ""):
|
|
194 |
translated_prompt = translator.translate(prompt, src='en', dest='ja').text
|
195 |
return translated_prompt
|
196 |
except Exception as e:
|
|
|
197 |
return prompt
|
198 |
|
199 |
def is_japanese(s):
|
@@ -219,7 +249,7 @@ def translate_prompt_to_ja(prompt: str = ""):
|
|
219 |
def tags_to_ja(itag, dict):
|
220 |
def t_to_j(match, dict):
|
221 |
tag = match.group(0)
|
222 |
-
ja = dict.get(tag
|
223 |
if ja:
|
224 |
return ja
|
225 |
else:
|
@@ -238,7 +268,7 @@ def convert_tags_to_ja(input_prompt: str = ""):
|
|
238 |
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
239 |
dict = tags_to_ja_dict
|
240 |
for tag in tags:
|
241 |
-
tag = tag
|
242 |
tag = tags_to_ja(tag, dict)
|
243 |
out_tags.append(tag)
|
244 |
|
@@ -365,7 +395,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
|
|
365 |
|
366 |
group_dict = tag_group_dict
|
367 |
for tag in tags:
|
368 |
-
tag = tag
|
369 |
if tag in PEOPLE_TAGS:
|
370 |
people_tags.append(tag)
|
371 |
elif is_necessary(tag, keep_tags, group_dict):
|
@@ -393,7 +423,7 @@ def sort_taglist(tags: list[str]):
|
|
393 |
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
394 |
|
395 |
for tag in tags:
|
396 |
-
tag = tag
|
397 |
if tag in PEOPLE_TAGS:
|
398 |
people_tags.append(tag)
|
399 |
elif tag in rating_set:
|
@@ -494,12 +524,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
|
|
494 |
output_series_tag = output_series_list[0]
|
495 |
else:
|
496 |
output_series_tag = ""
|
497 |
-
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
498 |
|
499 |
|
500 |
-
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
|
|
|
501 |
if not "Use WD Tagger" in algo and len(algo) != 0:
|
502 |
-
return
|
503 |
return predict_tags(image, general_threshold, character_threshold)
|
504 |
|
505 |
|
|
|
1 |
from PIL import Image
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
+
import spaces
|
|
|
5 |
from transformers import (
|
6 |
AutoImageProcessor,
|
7 |
AutoModelForImageClassification,
|
8 |
)
|
9 |
|
10 |
+
|
11 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
12 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
13 |
|
|
|
49 |
}
|
50 |
|
51 |
|
52 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# Emoticon tags that replace_underline() returns with their underscores intact.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def replace_underline(x: str) -> str:
    """Return ``x`` stripped, with underscores turned into spaces.

    Tags listed in ``kaomojis`` are returned stripped but otherwise
    untouched, because their underscores are part of the emoticon.

    Args:
        x: A single tag, possibly padded with whitespace.

    Returns:
        The stripped tag, with ``_`` replaced by a space unless the stripped
        tag is a kaomoji.
    """
    # Strip before the membership test: checking the raw value would let a
    # padded kaomoji such as " ^_^ " slip through and be mangled to "^ ^".
    stripped = x.strip()
    if stripped in kaomojis:
        return stripped
    return stripped.replace("_", " ")
|
78 |
+
|
79 |
+
|
80 |
def to_list(s):
    """Split a comma-separated string into a list of stripped items.

    Returns an empty list when ``s`` is ``None`` or the empty string.
    Otherwise every comma-delimited piece is kept, with surrounding
    whitespace removed; pieces that are blank after stripping are retained
    as ``""`` so results for non-empty input are unchanged.
    """
    # The emptiness test concerns the whole input, so evaluate it once up
    # front instead of once per element inside the comprehension.
    if not s:
        return []
    return [x.strip() for x in s.split(",")]
|
82 |
|
|
|
138 |
def danbooru_to_e621(dtag, e621_dict):
|
139 |
def d_to_e(match, e621_dict):
|
140 |
dtag = match.group(0)
|
141 |
+
etag = e621_dict.get(replace_underline(dtag), "")
|
142 |
if etag:
|
143 |
return etag
|
144 |
else:
|
|
|
162 |
|
163 |
e621_dict = danbooru_to_e621_dict
|
164 |
for tag in tags:
|
165 |
+
tag = replace_underline(tag)
|
166 |
tag = danbooru_to_e621(tag, e621_dict)
|
167 |
if tag in PEOPLE_TAGS:
|
168 |
people_tags.append(tag)
|
|
|
190 |
translated_prompt = translator.translate(prompt, src='auto', dest='en').text
|
191 |
return translated_prompt
|
192 |
except Exception as e:
|
193 |
+
print(e)
|
194 |
return prompt
|
195 |
|
196 |
def is_japanese(s):
|
|
|
223 |
translated_prompt = translator.translate(prompt, src='en', dest='ja').text
|
224 |
return translated_prompt
|
225 |
except Exception as e:
|
226 |
+
print(e)
|
227 |
return prompt
|
228 |
|
229 |
def is_japanese(s):
|
|
|
249 |
def tags_to_ja(itag, dict):
|
250 |
def t_to_j(match, dict):
|
251 |
tag = match.group(0)
|
252 |
+
ja = dict.get(replace_underline(tag), "")
|
253 |
if ja:
|
254 |
return ja
|
255 |
else:
|
|
|
268 |
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
269 |
dict = tags_to_ja_dict
|
270 |
for tag in tags:
|
271 |
+
tag = replace_underline(tag)
|
272 |
tag = tags_to_ja(tag, dict)
|
273 |
out_tags.append(tag)
|
274 |
|
|
|
395 |
|
396 |
group_dict = tag_group_dict
|
397 |
for tag in tags:
|
398 |
+
tag = replace_underline(tag)
|
399 |
if tag in PEOPLE_TAGS:
|
400 |
people_tags.append(tag)
|
401 |
elif is_necessary(tag, keep_tags, group_dict):
|
|
|
423 |
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
424 |
|
425 |
for tag in tags:
|
426 |
+
tag = replace_underline(tag)
|
427 |
if tag in PEOPLE_TAGS:
|
428 |
people_tags.append(tag)
|
429 |
elif tag in rating_set:
|
|
|
524 |
output_series_tag = output_series_list[0]
|
525 |
else:
|
526 |
output_series_tag = ""
|
527 |
+
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
528 |
|
529 |
|
530 |
+
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
                    character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
    """Run the WD tagger on ``image`` unless ``algo`` selects something else.

    The tagger runs when ``algo`` is empty or contains "Use WD Tagger";
    in that case the result of ``predict_tags`` is returned as-is.
    Otherwise the caller's current series, character and tag strings are
    handed back unchanged, followed by a ``gr.update(interactive=True)``.
    """
    wd_selected = "Use WD Tagger" in algo
    if wd_selected or len(algo) == 0:
        return predict_tags(image, general_threshold, character_threshold)
    # Another algorithm was chosen: pass the existing inputs straight through.
    return input_series, input_character, input_tags, gr.update(interactive=True)
|
535 |
|
536 |
|