Spaces:
Running
Running
import os | |
import shutil | |
from pathlib import Path | |
# One-time bootstrap, disabled: upgrade dependencies and vendor a local copy
# of the model into ./wuerstchen (dropping its .git history). The app loads
# the model with local_files_only=True, so that directory must already exist
# at startup. Kept for reference:
#
#   os.system("pip install -U huggingface_hub")
#   os.system("pip install -U diffusers")
#   if os.path.exists("wuerstchen"):
#       shutil.rmtree("wuerstchen")
#   os.system("git clone https://huggingface.co/warp-ai/wuerstchen")
#   if os.path.exists("wuerstchen/.git"):
#       shutil.rmtree("wuerstchen/.git")
import sys | |
import gradio as gr | |
import numpy as np | |
import torch | |
import random | |
from diffusers import AutoPipelineForText2Image | |
from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS | |
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The model is loaded only from a pre-cloned local directory.
MODEL_DIR = Path("wuerstchen")
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# and a missing model directory deserves an actionable error message.
if not MODEL_DIR.is_dir():
    raise FileNotFoundError(
        f"Model directory {str(MODEL_DIR)!r} not found; clone "
        "https://huggingface.co/warp-ai/wuerstchen into it before starting the app."
    )
pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_DIR, local_files_only=True, torch_dtype=torch.float32
)
# Alternative: load straight from the Hub instead of the local copy:
#   pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen",
#                                                    torch_dtype=torch.float32)
pipe.to(device)
# NOTE(review): no safety checker is used anywhere in this file; this
# assignment presumably just guarantees none is attached — confirm.
pipe.safety_checker = None

# Reference usage (original note: ~9 min per sample on 2 cores):
#   images = pipe(
#       "Anthropomorphic cat dressed as a fire fighter",
#       width=512,
#       height=512,
#       prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,  # 30 timesteps
#       prior_guidance_scale=4.0,
#       num_images_per_prompt=1,
#       num_inference_steps=6,  # default is 12; 6 is preferred here
#   ).images
def process(prompt, num_samples, image_resolution, sample_steps, seed):
    """Generate ``num_samples`` square images for ``prompt`` with ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        num_samples: Number of images; each comes from a separate pipeline call.
        image_resolution: Side length in pixels (``height == width``).
        sample_steps: Forwarded to the pipeline as ``num_inference_steps``.
        seed: RNG seed for generation; ``-1`` draws a random seed in
            ``[0, 65535]``.

    Returns:
        A list of ``num_samples`` images converted with ``np.asarray``
        (suitable for a ``gr.Gallery`` output).
    """
    # Gradio sliders may deliver floats (e.g. 512.0); range(), the pipeline's
    # height/width and Generator.manual_seed() all require ints.
    num_samples = int(num_samples)
    resolution = int(image_resolution)
    steps = int(sample_steps)
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 65535)

    # A single seeded generator shared by every call: its state advances
    # between samples, so the images differ from each other while the whole
    # batch is reproducible for a given seed (modulo nondeterministic kernels).
    generator = torch.Generator(device=device).manual_seed(seed)

    results = []
    with torch.no_grad():
        for _ in range(num_samples):
            image = pipe(
                prompt,
                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
                prior_guidance_scale=4.0,
                num_inference_steps=steps,
                num_images_per_prompt=1,
                height=resolution,
                width=resolution,
                generator=generator,
            ).images[0]
            results.append(np.asarray(image))
    return results
# Gradio UI (request queuing enabled). Layout adapted from
# https://github.com/svjack/ControlLoRA-Chinese.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Rapid Diffusion model from warp-ai/wuerstchen")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="Anthropomorphic cat dressed as a fire fighter")
            # The button caption is its `value`; `label` is not a Button caption.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=6, step=1)
                # -1 makes `process` draw a random seed.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
        with gr.Column():
            # NOTE(review): `.style()` is deprecated in later Gradio 3.x and removed
            # in 4.x; migrate to gr.Gallery(..., columns=2, height="auto") when
            # upgrading Gradio.
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Order must match the positional parameters of `process`.
    ips = [prompt, num_samples, image_resolution, sample_steps, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)
    gr.Examples(
        [
            ["A glass of cola, 8k", 1, 512, 8, 10],
            ["Anthropomorphic cat dressed as a fire fighter", 1, 512, 8, 20],
        ],
        inputs=ips,
        label="Examples",
    )
block.launch(server_name='0.0.0.0')