"""Gradio demo for ControlNet with Scribble and Pose (OpenPose) control tasks.

At import time this loads two ControlNet checkpoints from the
``lllyasviel/ControlNet`` Hugging Face repository, moves them to CUDA, and
builds a Gradio UI that turns an uploaded image plus a text prompt into
generated images.
"""
import random

import cv2
import einops
import gradio as gr
import numpy as np
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning import seed_everything

from annotator.openpose import apply_openpose
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from util import HWC3, apply_canny, resize_image

REPO_ID = "lllyasviel/ControlNet"
MODEL_CONFIG = './models/cldm_v15.yaml'
canny_checkpoint = "models/control_sd15_canny.pth"
scribble_checkpoint = "models/control_sd15_scribble.pth"
pose_checkpoint = "models/control_sd15_openpose.pth"

DEFAULT_A_PROMPT = 'best quality, extremely detailed'
DEFAULT_N_PROMPT = ('longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,'
                    'extra digit, fewer digits, cropped, worst quality, low quality')


def _load_control_model(checkpoint):
    """Build a model from MODEL_CONFIG, load *checkpoint* from REPO_ID, move it to CUDA.

    Returns:
        ``(model, sampler)`` where *sampler* is a DDIMSampler bound to *model*.
    """
    model = create_model(MODEL_CONFIG).cpu()
    # NOTE(review): cached_download is deprecated in recent huggingface_hub
    # releases (hf_hub_download replaces it) — verify against the pinned version.
    weights_path = cached_download(hf_hub_url(REPO_ID, checkpoint))
    model.load_state_dict(load_state_dict(weights_path, location='cuda'))
    model = model.cuda()
    return model, DDIMSampler(model)


pose_model, ddim_sampler_pose = _load_control_model(pose_checkpoint)
scribble_model, ddim_sampler_scribble = _load_control_model(scribble_checkpoint)


def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
            image_resolution, ddim_steps, scale, seed, eta,
            low_threshold=None, high_threshold=None):
    """Route a UI request to the control task named by *input_control*.

    Args:
        input_image: uploaded image as a numpy array.
        prompt: user text prompt.
        input_control: "Scribble" selects the scribble task; any other value
            selects the pose task.
        a_prompt: text appended to *prompt* (joined with ", ").
        n_prompt: negative prompt used for unconditional guidance.
        num_samples: number of images to generate.
        image_resolution: target resolution passed to ``resize_image``.
        ddim_steps, scale, seed, eta: DDIM sampling settings.
        low_threshold, high_threshold: currently unused (no Canny task is
            wired up yet). They are optional so that the 11-input UI callbacks
            can call this function.

    Returns:
        A list of numpy images: the control map followed by the
        ``num_samples`` generated images.
    """
    # TODO: Add other control tasks
    if input_control == "Scribble":
        return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples,
                                image_resolution, ddim_steps, scale, seed, eta)
    # The pose task reuses the output resolution as the OpenPose detection resolution.
    return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples,
                        image_resolution, image_resolution, ddim_steps, scale, seed, eta)


def _run_controlnet(model, sampler, detected_map, prompt, a_prompt, n_prompt,
                    num_samples, ddim_steps, scale, seed, eta):
    """Sample *num_samples* images from *model* conditioned on *detected_map*.

    Must be called inside ``torch.no_grad()``. *detected_map* is an
    (H, W, C) uint8 control image. The latent shape uses ``H // 8`` and
    ``W // 8``; presumably ``resize_image`` makes H and W multiples of 8 —
    TODO confirm.

    Returns:
        A list of *num_samples* uint8 numpy images, channels-last.
    """
    H, W, _ = detected_map.shape
    # Scale the map to [0, 1], replicate it once per sample, and convert to
    # channels-first for the model.
    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    # -1 requests a random seed. The UI slider never produces it, but direct callers may.
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    cond = {"c_concat": [control],
            "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_concat": [control],
               "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    shape = (4, H // 8, W // 8)
    samples, _ = sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
                                unconditional_guidance_scale=scale,
                                unconditional_conditioning=un_cond)

    x_samples = model.decode_first_stage(samples)
    # Affine map x*127.5+127.5 (presumably from [-1, 1]) into [0, 255], then
    # clip and cast to uint8, channels-last.
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5)
    x_samples = x_samples.cpu().numpy().clip(0, 255).astype(np.uint8)
    return [x_samples[i] for i in range(num_samples)]


def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples,
                     image_resolution, ddim_steps, scale, seed, eta):
    """Generate images guided by a scribble drawing.

    Pixels of the resized input whose darkest channel is below 127 become
    strokes (255) in the control map. Everything else is 0.

    Returns:
        ``[255 - control_map]`` (the inverted map, for display) followed by
        the generated images.
    """
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        detected_map = np.zeros_like(img, dtype=np.uint8)
        detected_map[np.min(img, axis=2) < 127] = 255
        results = _run_controlnet(scribble_model, ddim_sampler_scribble, detected_map,
                                  prompt, a_prompt, n_prompt, num_samples,
                                  ddim_steps, scale, seed, eta)
        return [255 - detected_map] + results


def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples,
                 image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
    """Generate images guided by an OpenPose map detected from *input_image*.

    Pose detection runs at *detect_resolution*. The resulting map is then
    resized to the size of the input resized to *image_resolution*.

    Returns:
        The pose map followed by the generated images.
    """
    with torch.no_grad():
        input_image = HWC3(input_image)
        detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, _ = img.shape
        # Nearest-neighbour avoids interpolating new values into the control map.
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        results = _run_controlnet(pose_model, ddim_sampler_pose, detected_map,
                                  prompt, a_prompt, n_prompt, num_samples,
                                  ddim_steps, scale, seed, eta)
        return [detected_map] + results


def create_canvas(w, h):
    """Return an all-white (255) uint8 canvas of shape (h, w, 3)."""
    return np.full((h, w, 3), 255, dtype=np.uint8)


block = gr.Blocks().queue()

control_task_list = [
    "Scribble",
    "Pose",
]

with block:
    gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
    gr.HTML('''

This is an unofficial demo for ControlNet, which is a neural network structure to control diffusion models by adding extra conditions such as canny edge detection. The demo is based on the Github implementation.

''')
    gr.HTML("""

You can duplicate this Space to run it privately without a queue and load additional checkpoints. : Duplicate Space Open in Colab

""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            input_control = gr.Dropdown(control_task_list, value="Scribble", label="Control Task")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768,
                                             value=512, step=256)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value=DEFAULT_A_PROMPT)
                n_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_N_PROMPT)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False,
                                        elem_id="gallery").style(grid=2, height='auto')

    # Order must match process()'s positional parameters; the Canny thresholds are omitted.
    ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
           image_resolution, ddim_steps, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

    examples_list = [
        [
            "turtle.png",
            "turtle",
            "Scribble",
            DEFAULT_A_PROMPT,
            DEFAULT_N_PROMPT,
            1,
            512,
            20,
            9.0,
            123490213,
            0.0,
        ]
    ]
    # Each example row supplies exactly the values for `ips`, in the same order.
    examples = gr.Examples(examples=examples_list, inputs=ips, outputs=[result_gallery],
                           cache_examples=True, fn=process)

block.launch(debug=True)