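# KuJiaLe Layout ControlNet demo: turns an input room image into new interior renderings by
# conditioning Stable Diffusion on depth, surface-normal, and segmentation maps extracted from it.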
import torch
# No-op patch: torch.jit.script returns the function unchanged, so model code decorated with it runs eagerly
torch.jit.script = lambda f: f
import spaces
import numpy as np
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
import gradio as gr
from huggingface_hub import hf_hub_download
from annotator.util import resize_image, HWC3
from annotator.midas import DepthDetector
from annotator.dsine_local import NormalDetector
from annotator.upernet import SegmDetector
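
# Layout ControlNet checkpoint (SD 1.5); it consumes the 7-channel depth/normal/segmentation conditioning built in generate()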
controlnet_checkpoint = "kujiale-ai/controlnet-layout"
# Initialize pipeline
controlnet = ControlNetModel.from_pretrained(
    controlnet_checkpoint,
    subfolder="control_v1_sd15_layout_fp16",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stablediffusionapi/realistic-vision-v51", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
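
# Annotators that produce the depth, surface-normal, and semantic-segmentation maps used as conditioning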
apply_depth = DepthDetector()
apply_normal = NormalDetector(hf_hub_download("camenduru/DSINE", filename="dsine.pt"))
apply_segm = SegmDetector()
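
# (input image, prompt, precomputed output) triples for the Examples panel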
layout_examples = [
    [
        "examples/layout_input.jpg",
        "A modern bedroom",
        "examples/layout_output.jpg",
    ],
    [
        "examples/living_and_dining_room_input.jpg",
        "A modern living and dining room",
        "examples/living_and_dining_room_output.jpg",
    ],
    [
        "examples/living_room_input.png",
        "A living room",
        "examples/living_room_output.jpg",
    ],
    [
        "examples/kitchen_input.jpg",
        "A furnished kitchen",
        "examples/kitchen_output.jpg",
    ],
]
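

# ZeroGPU: request a GPU for each call, with a 20-second duration budget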
@spaces.GPU(duration=20)
def generate(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    image_resolution,
    steps,
    strength,
    guidance_scale,
    seed,
):
    color_image = resize_image(HWC3(input_image), image_resolution)

    # Set the seed; -1 (the slider minimum) picks a random seed, since np.random.seed rejects negative values
    if seed == -1:
        seed = np.random.randint(0, 2147483647)
    np.random.seed(seed)
    torch.manual_seed(seed)

    with torch.no_grad():
        # Run the three annotators on the resized input
        depth_image = apply_depth(color_image)
        normal_image = apply_normal(color_image)
        segm_image = apply_segm(color_image)

        # Prepare the layout control image: depth (1 ch, [0, 1]), normals (3 ch, [-1, 1]),
        # and segmentation (3 ch, [0, 1]) concatenated into a 7-channel NCHW tensor
        depth_image = np.array(depth_image, dtype=np.float32) / 255.0
        depth_image = torch.from_numpy(depth_image[:, :, None])[None].permute(
            0, 3, 1, 2
        )
        normal_image = np.array(normal_image, dtype=np.float32)
        normal_image = normal_image / 127.5 - 1.0
        normal_image = torch.from_numpy(normal_image)[None].permute(0, 3, 1, 2)
        segm_image = np.array(segm_image, dtype=np.float32) / 255.0
        segm_image = torch.from_numpy(segm_image)[None].permute(0, 3, 1, 2)
        control_image = torch.cat([depth_image, normal_image, segm_image], dim=1)

        generator = torch.Generator(device="cuda").manual_seed(seed)
        images = pipe(
            prompt + ", " + a_prompt,  # append the quality prompt after a separator
            negative_prompt=n_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=steps,
            image=control_image,
            generator=generator,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(strength),
        ).images
    return images
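

# Gradio UI: input image, prompt, and advanced options on the left; output gallery on the right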
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## KuJiaLe Layout ControlNet Demo")
    with gr.Row():
        gr.Markdown(
            "### Check out our released model at [kujiale-ai/controlnet-layout](https://huggingface.co/kujiale-ai/controlnet-layout)"
        )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    sources="upload", type="numpy", label="Input Image", height=512
                )
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images", minimum=1, maximum=2, value=1, step=1
                )
                image_resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=512,
                    maximum=768,
                    value=768,
                    step=64,
                )
                strength = gr.Slider(
                    label="Control Strength",
                    minimum=0.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                )
                steps = gr.Slider(
                    label="Steps", minimum=1, maximum=50, value=25, step=1
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.1,
                    maximum=20.0,
                    value=7.5,
                    step=0.1,
                )
                seed = gr.Slider(
                    label="Seed", minimum=-1, maximum=2147483647, value=1, step=1
                )
                a_prompt = gr.Textbox(
                    label="Added Prompt", value="best quality, extremely detailed"
                )
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, human, extra digit, fewer digits, cropped, worst quality, low quality",
                )
        with gr.Column():
            image_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                height=512,
                object_fit="contain",
            )
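
    # Clicking an example loads its prompt and precomputed output via the lambda below; no GPU inference is run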
    with gr.Row():
        dummy_image_for_outputs = gr.Image(visible=False, label="Result")
        gr.Examples(
            fn=lambda *args: [[args[-1]], args[-2]],
            examples=layout_examples,
            inputs=[input_image, prompt, dummy_image_for_outputs],
            outputs=[image_gallery, prompt],
            run_on_click=True,
            examples_per_page=1024,
        )
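
    # Inputs passed to generate(), in signature order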
    ips = [
        input_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        steps,
        strength,
        guidance_scale,
        seed,
    ]
    run_button.click(fn=generate, inputs=ips, outputs=[image_gallery])
block.launch(server_name="0.0.0.0")