Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,315 Bytes
ef187eb e7915f0 0cffd40 ef187eb 11fa80e 63b6eaf 2b0f02c 11fa80e 0cffd40 8b1e96d 0cffd40 8b1e96d 0cccf69 8b1e96d ec35e66 4efab5c ec35e66 4efab5c 8b1e96d e7915f0 8b1e96d f286ae5 4b64a91 8b1e96d 9b38787 3a2b9b2 8b1e96d 9b38787 11fa80e 8b1e96d 3494613 6380dba 8b1e96d 3819ced 67399b5 1462211 fee8445 1462211 ef187eb 1462211 8b1e96d 0cffd40 8b3ca8d 0cffd40 8b1e96d 0cffd40 4efab5c 3639a4a 9b38787 8b1e96d 0cffd40 9b38787 fe16630 8b1e96d 8b3ca8d fe16630 8b3ca8d 8b1e96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
import requests
from translatepy import Translator
# Shared translatepy client; generate_image uses it to translate user prompts to English.
translator = Translator()
# Constants
# Base SDXL pipeline whose UNet weights are replaced by a DMD2 checkpoint.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hugging Face Hub repo that hosts the DMD2 UNet checkpoint files.
repo = "tianweiy/DMD2"
# UI label -> [checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    "1-Step": ["dmd2_sdxl_1step_unet_fp16.bin", 1],
    "4-Step": ["dmd2_sdxl_4step_unet_fp16.bin", 4],
}
# Step count of the checkpoint currently loaded into the pipeline's UNet;
# None until generate_image loads one.
loaded = None

# Narrow the app container and hide the Gradio footer.
CSS = """
.gradio-container {
max-width: 690px !important;
}
footer {
visibility: hidden;
}
"""

# Force Gradio's dark theme on page load. The URL API is used (instead of
# string concatenation) so an existing query string or #fragment is kept
# intact rather than producing a malformed "...?a=1?__theme=dark", and so a
# __theme=dark parameter is recognised wherever it appears in the query.
JS = """function () {
    const url = new URL(window.location.href);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.replace(url.href);
    }
}"""
# Build the SDXL pipeline once, at import time. The UNet is created from the
# base model's *config* only; its real weights come from the DMD2 checkpoint
# that generate_image loads via pipe.unet.load_state_dict before generating.
# Passing it as `unet=` makes the pipeline use this instance instead of also
# loading (and keeping on the GPU) a second UNet with base weights that would
# be overwritten anyway.
# NOTE(review): without CUDA, `unet`/`pipe` are never defined, so any call to
# generate_image fails with NameError.
if torch.cuda.is_available():
    unet = UNet2DConditionModel.from_config(
        UNet2DConditionModel.load_config(base, subfolder="unet")
    ).to("cuda", torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
# Function
# Scheduler timesteps used with each DMD2 checkpoint, keyed by step count.
_DMD2_TIMESTEPS = {
    1: [399],
    4: [999, 749, 499, 249],
}


def _translate_to_english(text):
    """Return *text* translated to English, or *text* unchanged if translation fails."""
    try:
        return str(translator.translate(text, 'English'))
    except Exception as exc:  # third-party network service: degrade, don't abort generation
        print(f"Translation failed, using original prompt: {exc!r}")
        return text


@spaces.GPU()
def generate_image(prompt, ckpt="4-Step"):
    """Generate one image for *prompt* with the selected DMD2 checkpoint.

    Args:
        prompt: Prompt text in any language; it is translated to English first
            (the original text is used if translation fails).
        ckpt: Key of ``checkpoints`` selecting the checkpoint and step count.

    Returns:
        The first image of the pipeline output (``results.images[0]``).

    Raises:
        KeyError: If *ckpt* is not a key of ``checkpoints``.
    """
    global loaded
    prompt = _translate_to_english(prompt)
    print(prompt)
    checkpoint, num_inference_steps = checkpoints[ckpt]
    # Reload UNet weights (and reset the scheduler) only when the requested
    # checkpoint differs from the one already in the pipeline.
    if loaded != num_inference_steps:
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        # The checkpoint is a plain state dict, so refuse arbitrary pickled
        # objects from the downloaded file.
        state_dict = torch.load(
            hf_hub_download(repo, checkpoint), map_location="cuda", weights_only=True
        )
        pipe.unet.load_state_dict(state_dict)
        loaded = num_inference_steps
    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        timesteps=_DMD2_TIMESTEPS[num_inference_steps],
    )
    return results.images[0]
# Example prompts offered in the UI through gr.Examples (prompt field only).
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic",
    "Ironman VS Hulk, ultrarealistic",
    "a CUTE robot artist painting on an easel",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Gradio Interface
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center><br><center>Multi-Languages, 4-step is higher quality & 2X slower</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter Your Prompt', scale=8)
            # Choices come from the checkpoint table so the UI cannot drift from it.
            ckpt = gr.Dropdown(label='Steps', choices=list(checkpoints), value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')
    # Examples supply only the prompt; generate_image then uses its default ckpt.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=img,
        fn=generate_image,
        cache_examples="lazy",
    )
    # Pressing Enter in the textbox and clicking the button run the same handler.
    gr.on(
        triggers=[prompt.submit, submit.click],
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )

demo.queue().launch()