Spaces:
Runtime error
Runtime error
File size: 5,370 Bytes
fd1c028 092fcaa fd1c028 092fcaa fd1c028 092fcaa fd1c028 092fcaa d639c7d fd1c028 092fcaa fd1c028 092fcaa d639c7d 092fcaa d639c7d 092fcaa d639c7d 092fcaa d639c7d 092fcaa fd1c028 d639c7d 092fcaa d639c7d 092fcaa d639c7d 092fcaa fd1c028 33739c5 47782a8 33739c5 fd1c028 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import mediapy
import sa_handler
import pipeline_calls
# init models
# Everything in this section runs at import time and places the models on
# the "cuda" device, so a CUDA-capable GPU is required to start the app.

# Depth estimation model and its matching image processor (same checkpoint).
_DPT_CHECKPOINT = "Intel/dpt-hybrid-midas"
depth_estimator = DPTForDepthEstimation.from_pretrained(_DPT_CHECKPOINT).to("cuda")
feature_processor = DPTImageProcessor.from_pretrained(_DPT_CHECKPOINT)

# Loading options shared by the ControlNet and the SDXL base pipeline:
# both are loaded from their fp16 safetensors variant in half precision.
_FP16_LOAD_KWARGS = {
    "variant": "fp16",
    "use_safetensors": True,
    "torch_dtype": torch.float16,
}

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    **_FP16_LOAD_KWARGS,
).to("cuda")

# A separate VAE checkpoint is loaded (in float16) and handed to the
# pipeline in place of the one bundled with the base model.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    **_FP16_LOAD_KWARGS,
).to("cuda")

# Memory-saving switches provided by the diffusers pipeline.
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# StyleAligned configuration: attention is shared and AdaIN is applied to
# queries and keys only; group-norm / layer-norm sharing and AdaIN on
# values are switched off.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)

# Attach the StyleAligned handler to the pipeline using the settings above.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)
# run ControlNet depth with StyleAligned
def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
    """Run depth ControlNet with StyleAligned and return images for the UI.

    Args:
        ref_style_prompt: First prompt passed to the pipeline call; the
            first returned image is treated as the style reference.
        depth_map: When truthy, a depth map is computed from ``ref_image``
            with ``pipeline_calls.get_depth_map``. Otherwise ``ref_image``
            itself, resized to 1024x1024, is used as the conditioning image.
        ref_image: Anything accepted by ``diffusers.utils.load_image``
            (the UI supplies a file path).
        img_generation_prompt: Second prompt passed to the pipeline call.

    Returns:
        A 2-tuple of
        * a list ``[reference image, conditioning image, *remaining images]``
          for the gallery, and
        * a visible ``gr.Image`` holding the reference image.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    # Fail with a readable UI message instead of an opaque load_image error.
    if not ref_image:
        raise gr.Error("Please upload a reference image.")

    source_image = load_image(ref_image)
    if depth_map:
        depth_image = pipeline_calls.get_depth_map(source_image, feature_processor, depth_estimator)
    else:
        # The uploaded image is used directly as the conditioning image.
        depth_image = source_image.resize((1024, 1024))

    controlnet_conditioning_scale = 0.8
    num_images_per_prompt = 3  # adjust according to VRAM size

    # One noise sample for the reference image plus one per generated image,
    # cast to the UNet's dtype. A single draw is enough: redrawing
    # latents[1:] right away would only replace fresh noise with more
    # fresh noise from the same distribution.
    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)

    images = pipeline_calls.controlnet_call(
        pipeline,
        [ref_style_prompt, img_generation_prompt],
        image=depth_image,
        num_inference_steps=50,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_images_per_prompt=num_images_per_prompt,
        latents=latents,
    )

    reference = images[0]
    gallery_items = [reference, depth_image, *images[1:]]
    return gallery_items, gr.Image(value=reference, visible=True)
# Gradio front end: two input panels side by side, a generate button and a
# gallery for the results.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped -- confirm where the button and gallery sit.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant="panel"):
            ref_style_prompt = gr.Textbox(
                label="Reference style prompt",
                info="Enter a Prompt to generate the reference image",
                placeholder="a poster in <style name> style",
            )
            depth_map = gr.Checkbox(label="Depth-map")
            # Created hidden; the callback returns a visible gr.Image for it.
            ref_style_image = gr.Image(visible=False, label="Reference style image")
        with gr.Column(variant="panel"):
            # type="filepath": the callback receives a path, not pixel data.
            ref_image = gr.Image(label="Upload the reference image", type="filepath")
            img_generation_prompt = gr.Textbox(
                label="ControlNet Prompt",
                info="Enter a Prompt to generate images using ControlNet and Style-aligned",
            )

    btn = gr.Button("Generate", size="sm")
    gallery = gr.Gallery(
        label="Style-Aligned ControlNet - Generated images",
        elem_id="gallery",
        columns=5,
        rows=1,
        object_fit="contain",
        height="auto",
    )

    # The button and the examples drive the same callback with the same
    # components, so the wiring is declared once.
    callback_inputs = [ref_style_prompt, depth_map, ref_image, img_generation_prompt]
    callback_outputs = [gallery, ref_style_image]

    btn.click(
        fn=style_aligned_controlnet,
        inputs=callback_inputs,
        outputs=callback_outputs,
        api_name="style_aligned_controlnet",
    )

    # Each row: [reference style prompt, depth-map flag, image path, ControlNet prompt]
    example_rows = [
        ["A poster in a papercut art style.", False, "example_image/A.png", "Letter A in a papercut art style."],
        ["A poster in a papercut art style.", True, "example_image/camel.jpg", "A camel in a papercut art style."],
        ["A couple sitting a wooden bench, in clay animation, claymation style.", False, "example_image/train.jpg", "A train in clay animation, claymation style."],
        ["A couple sitting a wooden bench, in clay animation, claymation style.", True, "example_image/sun.png", "Sun in clay animation, claymation style."],
        ["A bull in a low-poly, colorful origami style.", True, "example_image/whale.png", "A whale in a low-poly, colorful origami style."],
        ["A house in a painterly, digital illustration style.", True, "example_image/camel.jpg", "A camel in a painterly, digital illustration style."],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=callback_inputs,
        outputs=callback_outputs,
        fn=style_aligned_controlnet,
    )

demo.launch()
|