John6666 commited on
Commit
8c0ebd8
β€’
1 Parent(s): bbbad5b

Upload tagger.py

Browse files
Files changed (1) hide show
  1. tagger.py +42 -11
tagger.py CHANGED
@@ -1,13 +1,13 @@
1
  from PIL import Image
2
  import torch
3
  import gradio as gr
4
- import spaces # ZERO GPU
5
-
6
  from transformers import (
7
  AutoImageProcessor,
8
  AutoModelForImageClassification,
9
  )
10
 
 
11
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
12
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
13
 
@@ -49,6 +49,34 @@ DANBOORU_TO_E621_RATING_MAP = {
49
  }
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def to_list(s):
53
  return [x.strip() for x in s.split(",") if not s == ""]
54
 
@@ -110,7 +138,7 @@ def select_random_character(series: str, character: str):
110
  def danbooru_to_e621(dtag, e621_dict):
111
  def d_to_e(match, e621_dict):
112
  dtag = match.group(0)
113
- etag = e621_dict.get(dtag.strip().replace("_", " "), "")
114
  if etag:
115
  return etag
116
  else:
@@ -134,7 +162,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
134
 
135
  e621_dict = danbooru_to_e621_dict
136
  for tag in tags:
137
- tag = tag.strip().replace("_", " ")
138
  tag = danbooru_to_e621(tag, e621_dict)
139
  if tag in PEOPLE_TAGS:
140
  people_tags.append(tag)
@@ -162,6 +190,7 @@ def translate_prompt(prompt: str = ""):
162
  translated_prompt = translator.translate(prompt, src='auto', dest='en').text
163
  return translated_prompt
164
  except Exception as e:
 
165
  return prompt
166
 
167
  def is_japanese(s):
@@ -194,6 +223,7 @@ def translate_prompt_to_ja(prompt: str = ""):
194
  translated_prompt = translator.translate(prompt, src='en', dest='ja').text
195
  return translated_prompt
196
  except Exception as e:
 
197
  return prompt
198
 
199
  def is_japanese(s):
@@ -219,7 +249,7 @@ def translate_prompt_to_ja(prompt: str = ""):
219
  def tags_to_ja(itag, dict):
220
  def t_to_j(match, dict):
221
  tag = match.group(0)
222
- ja = dict.get(tag.strip().replace("_", " "), "")
223
  if ja:
224
  return ja
225
  else:
@@ -238,7 +268,7 @@ def convert_tags_to_ja(input_prompt: str = ""):
238
  tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
239
  dict = tags_to_ja_dict
240
  for tag in tags:
241
- tag = tag.strip().replace("_", " ")
242
  tag = tags_to_ja(tag, dict)
243
  out_tags.append(tag)
244
 
@@ -365,7 +395,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
365
 
366
  group_dict = tag_group_dict
367
  for tag in tags:
368
- tag = tag.strip().replace("_", " ")
369
  if tag in PEOPLE_TAGS:
370
  people_tags.append(tag)
371
  elif is_necessary(tag, keep_tags, group_dict):
@@ -393,7 +423,7 @@ def sort_taglist(tags: list[str]):
393
  rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
394
 
395
  for tag in tags:
396
- tag = tag.strip().replace("_", " ")
397
  if tag in PEOPLE_TAGS:
398
  people_tags.append(tag)
399
  elif tag in rating_set:
@@ -494,12 +524,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
494
  output_series_tag = output_series_list[0]
495
  else:
496
  output_series_tag = ""
497
- return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
498
 
499
 
500
- def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
 
501
  if not "Use WD Tagger" in algo and len(algo) != 0:
502
- return "", "", input_tags, gr.update(interactive=True),
503
  return predict_tags(image, general_threshold, character_threshold)
504
 
505
 
 
1
  from PIL import Image
2
  import torch
3
  import gradio as gr
4
+ import spaces
 
5
  from transformers import (
6
  AutoImageProcessor,
7
  AutoModelForImageClassification,
8
  )
9
 
10
+
11
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
12
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
13
 
 
49
  }
50
 
51
 
52
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
53
+ kaomojis = [
54
+ "0_0",
55
+ "(o)_(o)",
56
+ "+_+",
57
+ "+_-",
58
+ "._.",
59
+ "<o>_<o>",
60
+ "<|>_<|>",
61
+ "=_=",
62
+ ">_<",
63
+ "3_3",
64
+ "6_9",
65
+ ">_o",
66
+ "@_@",
67
+ "^_^",
68
+ "o_o",
69
+ "u_u",
70
+ "x_x",
71
+ "|_|",
72
+ "||_||",
73
+ ]
74
+
75
+
76
+ def replace_underline(x: str):
77
+ return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
78
+
79
+
80
  def to_list(s):
81
  return [x.strip() for x in s.split(",") if not s == ""]
82
 
 
138
  def danbooru_to_e621(dtag, e621_dict):
139
  def d_to_e(match, e621_dict):
140
  dtag = match.group(0)
141
+ etag = e621_dict.get(replace_underline(dtag), "")
142
  if etag:
143
  return etag
144
  else:
 
162
 
163
  e621_dict = danbooru_to_e621_dict
164
  for tag in tags:
165
+ tag = replace_underline(tag)
166
  tag = danbooru_to_e621(tag, e621_dict)
167
  if tag in PEOPLE_TAGS:
168
  people_tags.append(tag)
 
190
  translated_prompt = translator.translate(prompt, src='auto', dest='en').text
191
  return translated_prompt
192
  except Exception as e:
193
+ print(e)
194
  return prompt
195
 
196
  def is_japanese(s):
 
223
  translated_prompt = translator.translate(prompt, src='en', dest='ja').text
224
  return translated_prompt
225
  except Exception as e:
226
+ print(e)
227
  return prompt
228
 
229
  def is_japanese(s):
 
249
  def tags_to_ja(itag, dict):
250
  def t_to_j(match, dict):
251
  tag = match.group(0)
252
+ ja = dict.get(replace_underline(tag), "")
253
  if ja:
254
  return ja
255
  else:
 
268
  tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
269
  dict = tags_to_ja_dict
270
  for tag in tags:
271
+ tag = replace_underline(tag)
272
  tag = tags_to_ja(tag, dict)
273
  out_tags.append(tag)
274
 
 
395
 
396
  group_dict = tag_group_dict
397
  for tag in tags:
398
+ tag = replace_underline(tag)
399
  if tag in PEOPLE_TAGS:
400
  people_tags.append(tag)
401
  elif is_necessary(tag, keep_tags, group_dict):
 
423
  rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
424
 
425
  for tag in tags:
426
+ tag = replace_underline(tag)
427
  if tag in PEOPLE_TAGS:
428
  people_tags.append(tag)
429
  elif tag in rating_set:
 
524
  output_series_tag = output_series_list[0]
525
  else:
526
  output_series_tag = ""
527
+ return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
528
 
529
 
530
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
531
+ character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
532
  if not "Use WD Tagger" in algo and len(algo) != 0:
533
+ return input_series, input_character, input_tags, gr.update(interactive=True)
534
  return predict_tags(image, general_threshold, character_threshold)
535
 
536