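# Wuerstchen / app.py — text-to-image Gradio demo for warp-ai/wuerstchen.
# (Header recovered from the Hugging Face Spaces file page: author svjack,
#  commit 92b4dc2 "Update app.py", 4.59 kB.)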
import os
import shutil
from pathlib import Path
# One-time environment setup, kept for reference and currently disabled.
# When enabled it upgraded the Hugging Face libraries, then re-cloned the
# model repository into ./wuerstchen and removed the clone's .git directory:
#
#   os.system("pip install -U huggingface_hub")
#   os.system("pip install -U diffusers")
#   if os.path.exists("wuerstchen"):
#       shutil.rmtree("wuerstchen")
#   os.system("git clone https://huggingface.co/warp-ai/wuerstchen")
#   if os.path.exists("wuerstchen/.git"):
#       shutil.rmtree("wuerstchen/.git")
import sys
import gradio as gr
import numpy as np
import torch
import random
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Local snapshot of the warp-ai/wuerstchen model repository. It is loaded with
# local_files_only=True, so the directory must already exist next to this
# script (for example created with
# `git clone https://huggingface.co/warp-ai/wuerstchen`).
MODEL_DIR = Path("wuerstchen")
if not MODEL_DIR.exists():
    # Explicit exception rather than `assert`: asserts are stripped under
    # `python -O`, which would defer the failure to a confusing loader error.
    raise FileNotFoundError(
        f"Model directory '{MODEL_DIR}' not found; clone "
        "https://huggingface.co/warp-ai/wuerstchen into it before starting the app."
    )

pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_DIR,
    local_files_only=True,
    torch_dtype=torch.float32,
)
# Alternative: load directly from the Hub instead of the local snapshot:
#   pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen",
#                                                    torch_dtype=torch.float32)
pipe.to(device)
# NOTE(review): kept from the original template; presumably a no-op for the
# Wuerstchen pipeline (confirm it defines no safety checker).
pipe.safety_checker = None

# Example standalone call (roughly 9 min per sample on 2 CPU cores):
#   caption = "Anthropomorphic cat dressed as a fire fighter"
#   images = pipe(
#       caption,
#       width=512,
#       height=512,
#       prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,  # length of 30
#       prior_guidance_scale=4.0,
#       num_images_per_prompt=1,
#       num_inference_steps=6,  # default is 12; 6 preferred here
#   ).images
def process(prompt, num_samples, image_resolution, sample_steps, seed,):
    """Generate ``num_samples`` images for ``prompt`` with the global pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_samples: Number of images to generate; one pipeline call each.
        image_resolution: Edge length in pixels, used for both height and width.
        sample_steps: Passed to the pipeline as ``num_inference_steps``.
        seed: RNG seed for the request; ``-1`` picks a random seed in
            ``[0, 65535]``. The same seed and settings reproduce the same images.

    Returns:
        A list with one NumPy array per generated image, for the gallery.
    """
    with torch.no_grad():
        if seed == -1:
            seed = random.randint(0, 65535)
        # Gradio sliders may deliver floats; range(), the pipeline sizes and
        # Generator.manual_seed all require ints.
        num_samples = int(num_samples)
        size = int(image_resolution)
        steps = int(sample_steps)
        # A single generator for the whole request: consecutive samples differ
        # from each other, yet the whole set is reproducible from `seed`.
        generator = torch.Generator(device=device).manual_seed(int(seed))
        images = []
        for _ in range(num_samples):
            image = pipe(
                prompt,
                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
                prior_guidance_scale=4.0,
                num_inference_steps=steps,
                num_images_per_prompt=1,
                height=size,
                width=size,
                generator=generator,
            ).images[0]
            images.append(np.asarray(image))
    return images
# Gradio UI: prompt plus advanced options on the left, output gallery on the
# right. The queue is enabled because each generation call can take minutes.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Rapid Diffusion model from warp-ai/wuerstchen")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value = "Anthropomorphic cat dressed as a fire fighter")
            # `value` is the Button caption parameter; `label` is not.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=6, step=1)
                # Starts at a random value; -1 makes process() draw a random seed.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Order must match process(prompt, num_samples, image_resolution, sample_steps, seed).
    ips = [prompt, num_samples, image_resolution, sample_steps, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)
    gr.Examples(
        [
            ["A glass of cola, 8k", 1, 512, 8, 10],
            ["Anthropomorphic cat dressed as a fire fighter", 1, 512, 8, 20],
        ],
        inputs=ips,
        label="Examples",
    )
# Listen on all interfaces so the app is reachable from outside the container.
block.launch(server_name='0.0.0.0')