Update app.py
Browse files
app.py
CHANGED
@@ -227,6 +227,33 @@ def run_rmbg(img, sigma=0.0):
|
|
227 |
result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
|
228 |
return result.clip(0, 255).astype(np.uint8), alpha
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
@torch.inference_mode()
|
232 |
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
@@ -256,6 +283,7 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
|
|
256 |
|
257 |
rng = torch.Generator(device=device).manual_seed(int(seed))
|
258 |
|
|
|
259 |
fg = resize_and_center_crop(input_fg, image_width, image_height)
|
260 |
|
261 |
concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
|
@@ -277,7 +305,8 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
|
|
277 |
cross_attention_kwargs={'concat_conds': concat_conds},
|
278 |
).images.to(vae.dtype) / vae.config.scaling_factor
|
279 |
else:
|
280 |
-
bg =
|
|
|
281 |
bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
|
282 |
bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
|
283 |
latents = i2i_pipe(
|
@@ -333,10 +362,11 @@ def process(input_fg, prompt, image_width, image_height, num_samples, seed, step
|
|
333 |
return pytorch2numpy(pixels)
|
334 |
|
335 |
|
336 |
-
@spaces.GPU
|
337 |
@torch.inference_mode()
|
338 |
def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
339 |
-
input_fg, matting = run_rmbg(input_fg)
|
|
|
340 |
results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
|
341 |
return input_fg, results
|
342 |
|
@@ -378,14 +408,12 @@ class BGSource(Enum):
|
|
378 |
block = gr.Blocks().queue()
|
379 |
with block:
|
380 |
with gr.Row():
|
381 |
-
gr.Markdown("##
|
382 |
-
with gr.Row():
|
383 |
-
gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation")
|
384 |
with gr.Row():
|
385 |
with gr.Column():
|
386 |
with gr.Row():
|
387 |
-
input_fg = gr.Image(sources='upload', type="numpy", label="Image",
|
388 |
-
output_bg = gr.Image(type="numpy", label="Preprocessed Foreground"
|
389 |
prompt = gr.Textbox(label="Prompt")
|
390 |
bg_source = gr.Radio(choices=[e.value for e in BGSource],
|
391 |
value=BGSource.NONE.value,
|
@@ -400,8 +428,8 @@ with block:
|
|
400 |
seed = gr.Number(label="Seed", value=12345, precision=0)
|
401 |
|
402 |
with gr.Row():
|
403 |
-
image_width = gr.Slider(label="Image Width", minimum=256, maximum=
|
404 |
-
image_height = gr.Slider(label="Image Height", minimum=256, maximum=
|
405 |
|
406 |
with gr.Accordion("Advanced options", open=False):
|
407 |
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
|
@@ -415,15 +443,7 @@ with block:
|
|
415 |
result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
|
416 |
with gr.Row():
|
417 |
dummy_image_for_outputs = gr.Image(visible=False, label='Result')
|
418 |
-
|
419 |
-
fn=lambda *args: [[args[-1]], "imgs/dummy.png"],
|
420 |
-
examples=db_examples.foreground_conditioned_examples,
|
421 |
-
inputs=[
|
422 |
-
input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
|
423 |
-
],
|
424 |
-
outputs=[result_gallery, output_bg],
|
425 |
-
run_on_click=True, examples_per_page=1024
|
426 |
-
)
|
427 |
ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
|
428 |
relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
|
429 |
example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
|
|
|
227 |
result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
|
228 |
return result.clip(0, 255).astype(np.uint8), alpha
|
229 |
|
230 |
+
@torch.inference_mode()
|
231 |
+
def merge_alpha(img, sigma=0.0):
|
232 |
+
if img is None:
|
233 |
+
return None
|
234 |
+
|
235 |
+
if len(img.shape) == 2:
|
236 |
+
img = np.stack((img,)*3, axis=-1)
|
237 |
+
|
238 |
+
H, W, C = img.shape
|
239 |
+
print(f"img.shape: {img.shape}")
|
240 |
+
|
241 |
+
if C == 3:
|
242 |
+
img, _ = run_rmbg(img)
|
243 |
+
return img
|
244 |
+
elif C == 4:
|
245 |
+
rgb = img[:, :, :3].astype(np.float32)
|
246 |
+
alpha = img[:, :, 3].astype(np.float32) / 255.0
|
247 |
+
|
248 |
+
result = rgb * alpha[:, :, np.newaxis] + 255 * (1 - alpha[:, :, np.newaxis])
|
249 |
+
|
250 |
+
if sigma != 0:
|
251 |
+
result += sigma * alpha[:, :, np.newaxis]
|
252 |
+
|
253 |
+
return np.clip(result, 0, 255).astype(np.uint8)
|
254 |
+
else:
|
255 |
+
raise ValueError(f"Unexpected number of channels: {C}")
|
256 |
+
|
257 |
|
258 |
@torch.inference_mode()
|
259 |
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
|
|
283 |
|
284 |
rng = torch.Generator(device=device).manual_seed(int(seed))
|
285 |
|
286 |
+
#fg = input_fg
|
287 |
fg = resize_and_center_crop(input_fg, image_width, image_height)
|
288 |
|
289 |
concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
|
|
|
305 |
cross_attention_kwargs={'concat_conds': concat_conds},
|
306 |
).images.to(vae.dtype) / vae.config.scaling_factor
|
307 |
else:
|
308 |
+
#bg = input_bg
|
309 |
+
bg = resize_and_center_crop(input_bg, image_width, image_height)
|
310 |
bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
|
311 |
bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
|
312 |
latents = i2i_pipe(
|
|
|
362 |
return pytorch2numpy(pixels)
|
363 |
|
364 |
|
365 |
+
@spaces.GPU(duration=240)
|
366 |
@torch.inference_mode()
|
367 |
def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
368 |
+
#input_fg, matting = run_rmbg(input_fg)
|
369 |
+
input_fg = merge_alpha(input_fg)
|
370 |
results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
|
371 |
return input_fg, results
|
372 |
|
|
|
408 |
block = gr.Blocks().queue()
|
409 |
with block:
|
410 |
with gr.Row():
|
411 |
+
gr.Markdown("##ICLight without mask")
|
|
|
|
|
412 |
with gr.Row():
|
413 |
with gr.Column():
|
414 |
with gr.Row():
|
415 |
+
input_fg = gr.Image(sources='upload', type="numpy", label="Image", image_mode='RGBA')
|
416 |
+
output_bg = gr.Image(type="numpy", label="Preprocessed Foreground")
|
417 |
prompt = gr.Textbox(label="Prompt")
|
418 |
bg_source = gr.Radio(choices=[e.value for e in BGSource],
|
419 |
value=BGSource.NONE.value,
|
|
|
428 |
seed = gr.Number(label="Seed", value=12345, precision=0)
|
429 |
|
430 |
with gr.Row():
|
431 |
+
image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=512, step=64)
|
432 |
+
image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=640, step=64)
|
433 |
|
434 |
with gr.Accordion("Advanced options", open=False):
|
435 |
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
|
|
|
443 |
result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
|
444 |
with gr.Row():
|
445 |
dummy_image_for_outputs = gr.Image(visible=False, label='Result')
|
446 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
447 |
ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
|
448 |
relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
|
449 |
example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
|