# Hugging Face Space - running on ZeroGPU.
# Keep `spaces` as the first import so ZeroGPU can set up before torch touches CUDA.
import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
from diffusers.schedulers import TCDScheduler
from PIL import Image
SAFETY_CHECKER = True

# Dropdown label -> [LoRA file-name template, step count, guidance scale].
# The "{}" in each file name is filled with the pipeline mode ("sdxl" or
# "sd15") when the weights are loaded. The step count seeds the steps slider,
# and the guidance scale is passed straight to the diffusion pipeline.
checkpoints = {
    label: [weight_file, n_steps, cfg_scale]
    for label, weight_file, n_steps, cfg_scale in (
        ("2-Step", "pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0),
        ("4-Step", "pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0),
        ("8-Step", "pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0),
        ("16-Step", "pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0),
        ("Normal CFG 4-Step", "pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5),
        ("Normal CFG 8-Step", "pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5),
        ("Normal CFG 16-Step", "pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5),
        ("LCM-Like LoRA", "pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0),
    )
}

# Label + mode of the most recently loaded LoRA (None until the first
# generation); lets generation skip reloading weights already in place.
loaded = None
def _load_fp16_pipeline(pipeline_cls, repo_id, **extra):
    """Load `repo_id` as `pipeline_cls` in float16 and move it to CUDA."""
    return pipeline_cls.from_pretrained(
        repo_id, torch_dtype=torch.float16, **extra
    ).to("cuda")


# Both base pipelines are loaded once at start-up, and only when a GPU is
# visible; generate_image() picks one of them per request.
# NOTE(review): without CUDA, pipe_sdxl / pipe_sd15 are never defined.
if torch.cuda.is_available():
    pipe_sdxl = _load_fp16_pipeline(
        StableDiffusionXLPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
    )
    pipe_sd15 = _load_fp16_pipeline(
        StableDiffusionPipeline,
        "ZeroCool94/stable-diffusion-v1-5",
    )
if SAFETY_CHECKER:
    # These imports live inside the guard, so they are only needed when
    # filtering is enabled. `safety_checker` is a local module of this Space.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Score `images` with the safety checker.

        Args:
            images: Generated PIL images.

        Returns:
            The same `images` list, untouched, together with whatever the
            local safety checker returns as its flags (presumably one bool per
            image -- confirm against safety_checker.py).
        """
        batch = feature_extractor(images, return_tensors="pt").to("cuda")
        clip_pixels = batch.pixel_values
        # NOTE(review): the image list is wrapped in another list here; this
        # presumably relies on the local checker not indexing `images` per
        # image -- verify before changing.
        flags = safety_checker(images=[images], clip_input=clip_pixels.to("cuda"))
        return images, flags
@spaces.GPU
def generate_image(
    prompt,
    ckpt,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
    mode="sdxl",
):
    """Generate one image with a PCM LoRA on the SDXL or SD1.5 base pipeline.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        ckpt: Key into ``checkpoints``; selects the LoRA weights, the step
            count and the guidance scale.
        num_inference_steps: Step count from the UI slider. Only honoured for
            "LCM-Like LoRA"; every other checkpoint uses its own step count
            from ``checkpoints``.
        progress: Gradio progress tracker (track_tqdm mirrors tqdm bars).
        mode: "sdxl" selects ``pipe_sdxl``; any other value selects
            ``pipe_sd15``. Also used as the LoRA subfolder on the Hub.

    Returns:
        A PIL image. If the safety checker flags the result, a black image
        of the same size is returned instead and a warning is shown.
    """
    global loaded
    weight_template, default_steps, guidance_scale = checkpoints[ckpt]

    if ckpt != "LCM-Like LoRA":
        # The slider is locked for fixed-step checkpoints. Changing the
        # dropdown also fires update_steps, so the slider value received here
        # may still belong to the previous checkpoint.
        num_inference_steps = default_steps
    num_inference_steps = int(num_inference_steps)

    pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15
    lora_key = ckpt + mode
    if loaded != lora_key:
        # Invalidate first, so that if loading fails the next call retries
        # instead of assuming the old weights are still present.
        loaded = None
        # Drop any previously loaded LoRA so the new one replaces it rather
        # than being added alongside it.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(
            "wangfuyun/PCM_Weights",
            weight_name=weight_template.format(mode),
            subfolder=mode,
        )
        loaded = lora_key

    if ckpt == "LCM-Like LoRA":
        pipe.scheduler = LCMScheduler()
    else:
        pipe.scheduler = TCDScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            timestep_spacing="trailing",
        )

    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    image = results.images[0]
    if not SAFETY_CHECKER:
        return image

    images, has_nsfw_concepts = check_nsfw_images(results.images)
    if any(has_nsfw_concepts):
        gr.Warning("NSFW content detected.")
        # Black placeholder at the generated resolution (SDXL is not 512px).
        return Image.new("RGB", image.size)
    return images[0]
def update_steps(ckpt):
    """Build the slider update for a newly selected checkpoint.

    The slider is set to the checkpoint's step count from ``checkpoints``.
    Only "LCM-Like LoRA" lets the user edit it; every other checkpoint locks it.
    """
    return gr.update(
        interactive=(ckpt == "LCM-Like LoRA"),
        value=checkpoints[ckpt][1],
    )
# Cap the app width so the prompt row and output image stay readable.
css = """
.gradio-container {
    max-width: 60rem !important;
}
"""

# UI layout and event wiring:
# - Selecting a checkpoint updates the steps slider (update_steps) and also
#   regenerates on SDXL.
# - Pressing Enter in the prompt or "Run on SDXL" generates with SDXL.
# - "Run on SD15" generates with the SD1.5 pipeline.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Phased Consistency Model

Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.

[[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
"""
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", scale=8)
            ckpt = gr.Dropdown(
                label="Select inference steps",
                choices=list(checkpoints),
                value="4-Step",
            )
            # Locked by default; update_steps unlocks it for "LCM-Like LoRA".
            steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=20,
                step=1,
                value=4,
                interactive=False,
            )
            ckpt.change(
                fn=update_steps,
                inputs=[ckpt],
                outputs=[steps],
                queue=False,
                show_progress=False,
            )
            submit_sdxl = gr.Button("Run on SDXL", scale=1)
            submit_sd15 = gr.Button("Run on SD15", scale=1)

    img = gr.Image(label="PCM Image")

    # Each example is [prompt, checkpoint label, step count]; the step count
    # matches the checkpoint's entry in `checkpoints`.
    gr.Examples(
        examples=[
            ["An astronaut walking on the moon", "4-Step", 4],
            [
                "Photo of a dramatic cliffside lighthouse in a storm, waves crashing, symbol of guidance and resilience.",
                "8-Step",
                8,
            ],
            [
                "Vincent van Gogh style, painting, a boy, clouds in the sky",
                "Normal CFG 4-Step",
                4,
            ],
            [
                "Echoes of a forgotten song drift across the moonlit sea, where a ghost ship sails, its spectral crew bound to an eternal quest for redemption.",
                "4-Step",
                4,
            ],
            [
                "Roger rabbit as a real person, photorealistic, cinematic.",
                "16-Step",
                16,
            ],
            [
                "Standing tall amidst the ruins, a stone golem awakens, vines and flowers sprouting from the crevices in its body.",
                "LCM-Like LoRA",
                4,
            ],
        ],
        inputs=[prompt, ckpt, steps],
        outputs=[img],
        fn=generate_image,
        # Examples are generated on first click and cached afterwards.
        cache_examples="lazy",
    )

    # SDXL is generate_image's default mode.
    gr.on(
        fn=generate_image,
        triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
        inputs=[prompt, ckpt, steps],
        outputs=[img],
    )
    gr.on(
        fn=lambda *args: generate_image(*args, mode="sd15"),
        triggers=[submit_sd15.click],
        inputs=[prompt, ckpt, steps],
        outputs=[img],
    )

demo.queue(api_open=False).launch(show_api=False)