# FLUX.1-schnell / app.py
# Last change: "Update app.py" by Jeff850 (commit a782412, verified)
import gradio as gr
import numpy as np
import random
import spaces
import torch
import os
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# Precision used when loading the pipeline weights.
dtype = torch.bfloat16
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face access token, read from the environment variable named "waffles"
# (None when the variable is unset).
hf_token = os.getenv("waffles")
# Load the FLUX.1-schnell diffusion pipeline with the access token, in `dtype`
# precision, and move it to `device`.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
    token=hf_token,
).to(device)
# Largest seed offered by the UI / drawn when randomizing: max of a signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 2048
# Number of denoising steps used for every pipeline call in `infer`.
DEFAULT_INFERENCE_STEPS = 4
@spaces.GPU(duration=90)
def infer(
    prompt,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_images=1,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one or more images for ``prompt`` with the FLUX.1-schnell pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator; replaced when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        width: Output image width in pixels.
        height: Output image height in pixels.
        num_images: Number of images to generate (one pipeline call each).
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        A ``(images, seed)`` tuple: the list of generated images and the seed
        that was actually used, so the UI can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio documents Slider values as floats; manual_seed() and range()
    # require ints, so coerce the numeric UI inputs up front.
    seed = int(seed)
    num_images = int(num_images)
    width = int(width)
    height = int(height)

    # A single CPU generator (no device argument) is shared across calls, so
    # each image draws different noise while the batch stays reproducible
    # from `seed`.
    generator = torch.Generator().manual_seed(seed)
    images = [
        pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=DEFAULT_INFERENCE_STEPS,
            generator=generator,
            guidance_scale=0,  # Fixed at 0
        ).images[0]
        for _ in range(num_images)
    ]
    return images, seed
# Example prompts offered in the UI through gr.Examples.
examples = [
    "a white husky knocking everything down in a living room",
    "a tuxedo cat with a waffle in her mouth",
    "an anime Chiweenie Dog wearing a hoodie",
]

# Custom CSS for the Blocks app: the element with id "col-container" is
# centred horizontally (auto side margins) and at most 520px wide.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 520px;\n"
    "}\n"
)
# Build the Gradio UI: a styled column with the prompt box, result gallery,
# generation controls and example prompts, all wired to `infer`.
# NOTE(review): block nesting below was reconstructed from a copy that lost its
# indentation — confirm the intended layout (e.g. where `num_images` sits).
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header text (plain string: there is nothing to interpolate).
        gr.Markdown("""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
""")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=5,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # Shows every image returned by one `infer` call.
        result = gr.Gallery(label="Result", show_label=False)

        # Also an output: `infer` writes back the seed it actually used.
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )

        num_images = gr.Slider(
            label="Number of images",
            minimum=1,
            maximum=4,
            step=1,
            value=1,
        )

        # Only the prompt is passed, so example runs use infer's default
        # seed/size arguments; results are cached lazily.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Generate on a Run-button click or on submitting the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_images],
        outputs=[result, seed],
    )

demo.launch()