p1atdev committed on
Commit f357c72
1 Parent(s): 0531bbd

initial commit

Files changed (2)
  1. app.py +181 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,181 @@
+ from PIL import Image
+
+ import torch
+
+ import gradio as gr
+
+ from transformers import (
+     AutoImageProcessor,
+     AutoModelForImageClassification,
+ )
+
+ # import spaces # ZERO GPU
+
+ MODEL_NAME = "p1atdev/wd-swinv2-tagger-v3-hf"
+
+ MODEL_NAMES = {
+     "swinv2-v3": "p1atdev/wd-swinv2-tagger-v3-hf",
+ }
+
+ model = AutoModelForImageClassification.from_pretrained(
+     MODEL_NAME,
+ )
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
+     return (
+         [f"1{noun}"]
+         + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
+         + [f"{maximum+1}+{noun}s"]
+     )
+
+
+ PEOPLE_TAGS = (
+     _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
+ )
+ RATING_MAP = {
+     "general": "safe",
+     "sensitive": "sensitive",
+     "questionable": "nsfw",
+     "explicit": "explicit, nsfw",
+ }
+
+ DESCRIPTION_MD = """
+ # WD Tagger with 🤗 transformers
+ Currently supports the following model(s):
+ - [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)
+
+ """.strip()
+
+
+ def postprocess_results(
+     results: dict[str, float], general_threshold: float, character_threshold: float
+ ):
+     results = {
+         k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
+     }
+
+     rating = {}
+     character = {}
+     general = {}
+
+     for k, v in results.items():
+         if k.startswith("rating:"):
+             rating[k.replace("rating:", "")] = v
+             continue
+         elif k.startswith("character:"):
+             character[k.replace("character:", "")] = v
+             continue
+
+         general[k] = v
+
+     character = {k: v for k, v in character.items() if v >= character_threshold}
+     general = {k: v for k, v in general.items() if v >= general_threshold}
+
+     return rating, character, general
+
+
+ def animagine_prompt(rating: list[str], character: list[str], general: list[str]):
+     people_tags: list[str] = []
+     other_tags: list[str] = []
+     rating_tag = RATING_MAP[rating[0]]
+
+     for tag in general:
+         if tag in PEOPLE_TAGS:
+             people_tags.append(tag)
+         else:
+             other_tags.append(tag)
+
+     all_tags = people_tags + character + other_tags + [rating_tag]
+
+     return ", ".join(all_tags)
+
+
+ # @spaces.GPU
+ @torch.no_grad()
+ def predict_tags(
+     image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8
+ ):
+     inputs = processor.preprocess(image, return_tensors="pt")
+
+     outputs = model(**inputs.to(model.device, model.dtype))
+     logits = torch.sigmoid(outputs.logits[0]) # take the first logits
+
+     # get probabilities
+     results = {
+         model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
+     }
+
+     # rating, character, general
+     rating, character, general = postprocess_results(
+         results, general_threshold, character_threshold
+     )
+
+     prompt = animagine_prompt(
+         list(rating.keys()), list(character.keys()), list(general.keys())
+     )
+
+     return rating, character, general, prompt
+
+
+ def demo():
+     with gr.Blocks() as ui:
+         gr.Markdown(DESCRIPTION_MD)
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Input image", type="pil")
+
+                 with gr.Group():
+                     general_threshold = gr.Slider(
+                         label="Threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.3,
+                         step=0.01,
+                         interactive=True,
+                     )
+                     character_threshold = gr.Slider(
+                         label="Character threshold",
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.8,
+                         step=0.01,
+                         interactive=True,
+                     )
+
+                 _model_radio = gr.Radio(
+                     choices=list(MODEL_NAMES.keys()),
+                     label="Model",
+                     value=list(MODEL_NAMES.keys())[0],
+                 )
+
+                 start_btn = gr.Button(value="Start", variant="primary")
+
+             with gr.Column():
+                 prompt_text = gr.Text(label="Prompt")
+
+                 rating_tags_label = gr.Label(label="Rating tags")
+                 character_tags_label = gr.Label(label="Character tags")
+                 general_tags_label = gr.Label(label="General tags")
+
+         start_btn.click(
+             predict_tags,
+             inputs=[input_image, general_threshold, character_threshold],
+             outputs=[
+                 rating_tags_label,
+                 character_tags_label,
+                 general_tags_label,
+                 prompt_text,
+             ],
+         )
+
+     ui.launch(
+         # debug=True,
+         # share=True,
+     )
+
+
+ if __name__ == "__main__":
+     demo()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ accelerate
+ transformers==4.38.2
+ spaces