hideosnes committed on
Commit 2091d9d
1 Parent(s): b52d6ea

Create app.py

Files changed (1)
app.py +350 -0
app.py ADDED
@@ -0,0 +1,350 @@
+ import cv2
+ import torch
+ import random
+ import tempfile
+ import numpy as np
+ from pathlib import Path
+ from PIL import Image
+ from diffusers import (
+     ControlNetModel,
+     StableDiffusionXLControlNetPipeline,
+     UNet2DConditionModel,
+     EulerDiscreteScheduler,
+ )
+ import spaces
+ import gradio as gr
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from ip_adapter import IPAdapterXL
+ from safetensors.torch import load_file
+
+ snapshot_download(
+     repo_id="h94/IP-Adapter", allow_patterns="sdxl_models/*", local_dir="."
+ )
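+ # (fetches the IP-Adapter SDXL checkpoints and the image encoder into ./sdxl_models)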
+
+ # CPU fallback & pipeline definition
+ MAX_SEED = np.iinfo(np.int32).max
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if "cuda" in device else torch.float32
+
+ # load models & scheduler (==> Euler) & ControlNet (==> Canny > test what's better!!!)
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ image_encoder_path = "sdxl_models/image_encoder"
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
+
+ controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
+ controlnet = ControlNetModel.from_pretrained(
+     controlnet_path, use_safetensors=False, torch_dtype=dtype  # dtype follows the CPU fallback
+ ).to(device)
+
+ # load SDXL-Lightning >> put Turbo here if we fall back to Comfy @Litto
+
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     base_model_path,
+     controlnet=controlnet,
+     torch_dtype=dtype,
+     variant="fp16",
+     add_watermark=False,
+ ).to(device)
+ pipe.set_progress_bar_config(disable=True)
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
+     pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
+ )
+ pipe.unet.load_state_dict(
+     load_file(
+         hf_hub_download(
+             "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
+         ),
+         device=device,  # honor the CPU fallback
+     )
+ )
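+ # Note: the 2-step Lightning UNet is distilled for CFG-free sampling with "trailing"
+ # timestep spacing, hence the scheduler config above and the guidance_scale=0 default below.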
+
+ # Load the IP-Adapter with specific target blocks for style transfer and layout
+ # preservation. Should be better than Comfy! Test this!
+ # target_blocks=["block"] for the original IP-Adapter
+ # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
+ # target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] for style+layout blocks
+ ip_model = IPAdapterXL(
+     pipe,
+     image_encoder_path,
+     ip_ckpt,
+     device,
+     target_blocks=["up_blocks.0.attentions.1"],
+ )
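+ # This global instance is the style-only default; create_image() below rebuilds the
+ # adapter per request when a different target mode is selected.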
+
+ # Resize the input image
+ # OpenCV goes here!!!
+ # Test this with a smaller side number for faster inference
+
+ def resize_img(
+     input_image,
+     max_side=1280,
+     min_side=1024,
+     size=None,
+     pad_to_max_side=False,
+     mode=Image.BILINEAR,
+     base_pixel_number=64,
+ ):
+     w, h = input_image.size
+     if size is not None:
+         w_resize_new, h_resize_new = size
+     else:
+         ratio = min_side / min(h, w)
+         w, h = round(ratio * w), round(ratio * h)
+         ratio = max_side / max(h, w)
+         input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+     if pad_to_max_side:
+         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+         offset_x = (max_side - w_resize_new) // 2
+         offset_y = (max_side - h_resize_new) // 2
+         res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = (
+             np.array(input_image)
+         )
+         input_image = Image.fromarray(res)
+     return input_image
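+ # usage sketch (hypothetical path): resize_img(Image.open("input.jpg"), max_side=1024)
+ # yields sides snapped to multiples of base_pixel_number (64), the granularity the
+ # SDXL VAE/UNet expect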
+
+ # example images for the endpoints --> info for Johannes/Jascha on what to expect
+
+ examples = [
+     [
+         "./assets/zeichnung1.jpg",
+         None,
+         "3D model, cute monster, test prompt",
+         1.0,
+         0.0,
+     ],
+     [
+         "./assets/zeichnung2.jpg",
+         "./assets/guidance-target.jpg",
+         "3D model, cute, kawaii, monster, another test prompt",
+         1.0,
+         0.6,
+     ],
+ ]
+
+ def run_for_examples(style_image, source_image, prompt, scale, control_scale):
+     return create_image(
+         image_pil=style_image,
+         input_image=source_image,
+         prompt=prompt,
+         n_prompt="text, watermark, low res, low quality, worst quality, deformed, blurry",
+         scale=scale,
+         control_scale=control_scale,
+         guidance_scale=0.0,
+         num_inference_steps=2,
+         seed=42,
+         target="Load only style blocks",
+         neg_content_prompt="",
+         neg_content_scale=0,
+     )
+
+ # Main function for image synthesis (input -> run_for_examples)
+
+ @spaces.GPU(enable_queue=True)
+ def create_image(
+     image_pil,
+     input_image,
+     prompt,
+     n_prompt,
+     scale,
+     control_scale,
+     guidance_scale,
+     num_inference_steps,
+     seed=-1,
+     target="Load only style blocks",
+     neg_content_prompt=None,
+     neg_content_scale=0,
+ ):
+     seed = random.randint(0, MAX_SEED) if seed == -1 else seed
+     if target == "Load original IP-Adapter":
+         # target_blocks=["blocks"] for the original IP-Adapter
+         ip_model = IPAdapterXL(
+             pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"]
+         )
+     elif target == "Load only style blocks":
+         # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
+         ip_model = IPAdapterXL(
+             pipe,
+             image_encoder_path,
+             ip_ckpt,
+             device,
+             target_blocks=["up_blocks.0.attentions.1"],
+         )
+     elif target == "Load style+layout block":
+         # target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] for style+layout blocks
+         ip_model = IPAdapterXL(
+             pipe,
+             image_encoder_path,
+             ip_ckpt,
+             device,
+             target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
+         )
+
+     if input_image is not None:
+         input_image = resize_img(input_image, max_side=1024)
+         cv_input_image = pil_to_cv2(input_image)
+         detected_map = cv2.Canny(cv_input_image, 50, 200)
+         canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
+     else:
+         canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
+         control_scale = 0
+
+     if float(control_scale) == 0:
+         canny_map = canny_map.resize((1024, 1024))
+
+     # pass the optional neg-content args only when actually set (assumes the
+     # InstantStyle ip_adapter, whose generate() accepts neg_content_prompt/scale)
+     gen_kwargs = {}
+     if neg_content_prompt and neg_content_scale != 0:
+         gen_kwargs["neg_content_prompt"] = neg_content_prompt
+         gen_kwargs["neg_content_scale"] = neg_content_scale
+     images = ip_model.generate(
+         pil_image=image_pil,
+         prompt=prompt,
+         negative_prompt=n_prompt,
+         scale=scale,
+         guidance_scale=guidance_scale,
+         num_samples=1,
+         num_inference_steps=num_inference_steps,
+         seed=seed,
+         image=canny_map,
+         controlnet_conditioning_scale=float(control_scale),
+         **gen_kwargs,
+     )
+     image = images[0]
+     with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
+         image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)  # check what happens to imgs when this changes!!!
+     return Path(tmpfile.name)
+
+ def pil_to_cv2(image_pil):
+     image_np = np.array(image_pil)
+     image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+     return image_cv2
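+ # (PIL arrays are RGB while OpenCV expects BGR, hence the channel swap before cv2.Canny)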
+
+ # Gradio description & frontend stuff for the Space (remove this for the Endpoint)
+ title = r"""
+ <h1 align="center">MewMewMew: Simsalabim!</h1>
+ """
+
+ description = r"""
+ <b>Let's test this! ARM <3 GoldExtra</b><br>
+ <b>SDXL-Lightning && IP-Adapter</b>
+ """
+
+ article = r"""
+ Ask Hidéo if something breaks: <a href="mailto:[email protected]">Hidéo's Mail</a>
+ """
+
+ block = gr.Blocks()
+ with block:
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Tabs():
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         image_pil = gr.Image(label="Style Image", type="pil")
+                     with gr.Column():
+                         prompt = gr.Textbox(
+                             label="Prompt",
+                             value="mewmewmew, kitty cats, unicorns, uWu",
+                         )
+
+                 scale = gr.Slider(
+                     minimum=0, maximum=2.0, step=0.01, value=1.0, label="Scale // scale"
+                 )
+                 with gr.Accordion(open=False, label="Expand for details!"):
+                     target = gr.Radio(
+                         [
+                             "Load only style blocks",
+                             "Load style+layout block",
+                             "Load original IP-Adapter",
+                         ],
+                         value="Load only style blocks",
+                         label="Select IP-Adapter mode",
+                     )
+
+                     with gr.Column():
+                         src_image_pil = gr.Image(
+                             label="Guidance Image (optional)", type="pil"
+                         )
+                     control_scale = gr.Slider(
+                         minimum=0, maximum=1.0, step=0.1, value=0.5,
+                         label="ControlNet strength // control_scale",
+                     )
+                     n_prompt = gr.Textbox(
+                         label="Negative Prompts",
+                         value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+                     )
+                     neg_content_prompt = gr.Textbox(
+                         label="Negative Content Prompt (optional)", value=""
+                     )
+                     neg_content_scale = gr.Slider(
+                         minimum=0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                         label="Negative content strength // neg_content_scale",
+                     )
+                     guidance_scale = gr.Slider(
+                         minimum=0,
+                         maximum=10.0,
+                         step=0.01,
+                         value=0.0,
+                         label="guidance-scale",
+                     )
+                     num_inference_steps = gr.Slider(
+                         minimum=2,
+                         maximum=50.0,
+                         step=1.0,
+                         value=2,
+                         label="Number of inference steps (optional) // num_inference_steps",
+                     )
+                     seed = gr.Slider(
+                         minimum=-1,
+                         maximum=MAX_SEED,
+                         value=-1,
+                         step=1,
+                         label="Seed Value // -1 = random // Seed-Proof=True",
+                     )
+
+                 generate_button = gr.Button("Simsalabim")
+
+             with gr.Column():
+                 generated_image = gr.Image(label="MewMewMagix uWu")
+
+     inputs = [
+         image_pil,
+         src_image_pil,
+         prompt,
+         n_prompt,
+         scale,
+         control_scale,
+         guidance_scale,
+         num_inference_steps,
+         seed,
+         target,
+         neg_content_prompt,
+         neg_content_scale,
+     ]
+     outputs = [generated_image]
+
+     gr.on(
+         triggers=[
+             prompt.input,
+             generate_button.click,
+             guidance_scale.input,
+             scale.input,
+             control_scale.input,
+             seed.input,
+         ],
+         fn=create_image,
+         inputs=inputs,
+         outputs=outputs,
+         show_progress="minimal",
+         show_api=False,
+         trigger_mode="always_last",
+     )
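+     # trigger_mode="always_last" coalesces rapid slider/typing events: while a run
+     # is in flight, only the most recent pending trigger fires afterwards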
+
+     gr.Examples(
+         examples=examples,
+         inputs=[image_pil, src_image_pil, prompt, scale, control_scale],
+         fn=run_for_examples,
+         outputs=[generated_image],
+         cache_examples=True,
+     )
+
+     gr.Markdown(article)
+
+ block.queue(api_open=False)
+ block.launch(show_api=False)