WLgiflora / app.py
shweaung's picture
Update app.py
58549d1 verified
raw
history blame
5.72 kB
import os
import io
import random
import requests
import gradio as gr
import numpy as np
from PIL import Image
# Timeout passed to requests.post() for the inference call (seconds).
timeout = 100

# Hugging Face access token, read from the environment once at import time,
# and the matching Authorization header.
API_TOKEN = os.environ.get("HF_TOKEN")
headers = {"Authorization": "Bearer {}".format(API_TOKEN)}

# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
def split_image(input_image, num_splits=4):
    """Cut an image into a square grid of equally sized tiles.

    The tiles are returned in row-major order (left to right, then top to
    bottom), so for the default 2x2 grid the order is top-left, top-right,
    bottom-left, bottom-right.

    Args:
        input_image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``,
            e.g. a PIL image.
        num_splits: Total number of tiles; must be a positive perfect square
            (1, 4, 9, ...).

    Returns:
        A list with ``num_splits`` cropped tiles.

    Raises:
        ValueError: If ``num_splits`` is not a positive perfect square.
    """
    if num_splits < 1:
        raise ValueError("num_splits must be a positive perfect square")
    grid = int(round(num_splits ** 0.5))
    if grid * grid != num_splits:
        raise ValueError("num_splits must be a positive perfect square")

    # Derive the tile size from the actual image instead of assuming
    # 1024x1024, so differently sized responses are still split correctly.
    width, height = input_image.size
    tile_w = width // grid
    tile_h = height // grid

    return [
        input_image.crop(
            (col * tile_w, row * tile_h, (col + 1) * tile_w, (row + 1) * tile_h)
        )
        for row in range(grid)
        for col in range(grid)
    ]
def export_to_gif(images, output_path, fps=4):
    """Write a sequence of frames to an endlessly looping animated GIF.

    Args:
        images: Non-empty sequence of frames; the first frame's ``save``
            method is used to write the file (PIL images in this app).
        output_path: Destination passed to ``save``.
        fps: Frames per second; must be greater than zero.

    Raises:
        ValueError: If ``images`` is empty or ``fps`` is not positive.
    """
    if not images:
        raise ValueError("export_to_gif requires at least one frame")
    if fps <= 0:
        raise ValueError("fps must be greater than zero")

    # Per-frame display time in milliseconds, derived from the frame rate.
    frame_duration_ms = int(1000 / fps)

    first_frame, *remaining_frames = images
    first_frame.save(
        output_path,
        save_all=True,
        append_images=remaining_frames,
        duration=frame_duration_ms,
        loop=0,  # 0 means loop forever
    )
def predict(prompt, seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, lora_id="black-forest-labs/FLUX.1-dev", progress=gr.Progress(track_tqdm=True)):
    """Generate a 2x2 grid image for ``prompt`` and turn it into a GIF.

    The prompt is wrapped in a template asking for four consecutive stills,
    posted to the Hugging Face inference endpoint for ``lora_id``, and the
    returned image is split into four tiles that are saved as ``output.gif``.

    Args:
        prompt: User description of the animation.
        seed: Seed sent with the request; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Sent as ``cfg_scale`` in the payload.
        num_inference_steps: Sent as ``steps`` in the payload.
        lora_id: Model repository id; ``None`` or blank falls back to
            ``black-forest-labs/FLUX.1-dev``.
        progress: Gradio progress tracker (not used directly here).

    Returns:
        Tuple ``(gif_path, image, seed)``.

    Raises:
        gr.Error: If the request fails, the API returns a non-200 status, or
            the response cannot be converted into a GIF.
    """
    prompt_template = f"""a 2x2 total 4 grid of frames, showing consecutive stills from a looped gif of {prompt}"""

    # Test for None first: calling .strip() on None would raise AttributeError.
    if lora_id is None or lora_id.strip() == "":
        lora_id = "black-forest-labs/FLUX.1-dev"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Short random id, used only to tag the completion log line.
    key = random.randint(0, 999)

    API_URL = "https://api-inference.huggingface.co/models/" + lora_id.strip()
    token = os.getenv("HF_TOKEN")
    request_headers = {"Authorization": f"Bearer {token}"}

    # NOTE(review): the generation settings are sent as top-level keys; the
    # hosted API may expect them nested under "parameters" with different
    # names -- confirm against the Inference API documentation.
    payload = {
        "inputs": prompt_template,
        "steps": num_inference_steps,
        "cfg_scale": guidance_scale,
        "seed": seed,
    }

    try:
        response = requests.post(API_URL, headers=request_headers, json=payload, timeout=timeout)
    except requests.exceptions.RequestException as e:
        # Surface timeouts / connection problems in the UI instead of a traceback.
        print(f"Error: request to the inference API failed: {e}")
        raise gr.Error(f"Request to the inference API failed: {e}") from e

    if response.status_code != 200:
        print(f"Error: Failed to get image. Response status: {response.status_code}")
        print(f"Response content: {response.text}")
        if response.status_code == 503:
            raise gr.Error(f"{response.status_code} : The model is being loaded")
        raise gr.Error(f"{response.status_code}")

    try:
        image = Image.open(io.BytesIO(response.content))
        print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
        split_images = split_image(image, num_splits=4)
        # Path to save the GIF
        gif_path = "output.gif"
        export_to_gif(split_images, gif_path, fps=4)
    except Exception as e:
        print(f"Error when trying to open the image: {e}")
        # The UI wires three outputs to this function, so returning a bare
        # None would fail later with an unrelated message; raise instead.
        raise gr.Error(f"Could not build a GIF from the API response: {e}") from e

    return gif_path, image, seed
# Page-level CSS handed to gr.Blocks: centres the main column, caps its width,
# and limits the height of the element with id "stills".
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
#stills{max-height:160px}
"""

# Sample prompts offered through gr.Examples.
examples = [
    "a cat waving its paws in the air",
    "a panda moving their hips from side to side",
    "a flower going through the process of blooming",
]
# Build the UI: prompt row, GIF and stills outputs, advanced settings,
# clickable examples, and the event wiring that calls predict().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Walone Gif Generator. Gif Animation ပုံထုတ်စနစ်")
        gr.Markdown("Create GIFs with Walone. Based on Flux Model.")
        gr.Markdown("Add LoRA (if needed) in Advanced Settings. For better results, include a description of the motion in your prompt.")

        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=4, show_copy_button=True, placeholder="Enter your prompt", container=False)
            submit = gr.Button("Submit", scale=0)

        output = gr.Image(label="GIF", show_label=False)
        output_stills = gr.Image(label="stills", show_label=False, elem_id="stills")

        with gr.Accordion("Advanced Settings", open=False):
            custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path (optional)", placeholder="multimodalart/vintage-ads-flux")
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        # Examples pass only the prompt; predict() uses its defaults for the rest.
        gr.Examples(
            examples=examples,
            fn=predict,
            inputs=[prompt],
            outputs=[output, output_stills, seed],
            cache_examples="lazy",
        )

    # Run predict() on either the Submit button or Enter in the prompt box.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[prompt, seed, randomize_seed, guidance_scale, num_inference_steps, custom_lora],
        outputs=[output, output_stills, seed],
    )

demo.launch()