|
import base64
import os
import traceback
from io import BytesIO

import fal_client
import gradio as gr
import requests
from fal_client.client import FalClientError
from PIL import Image
|
|
|
def _format_fal_error(exc):
    """Build the user-facing text for a ``FalClientError``.

    ``exc.args[0]`` may be a list of validation-error dicts carrying a
    ``'msg'`` key, or a plain string / other payload. Either shape is
    rendered as one message per line, so this helper never raises while an
    error is already being handled.
    """
    detail = exc.args[0] if exc.args else exc
    if isinstance(detail, (list, tuple)):
        messages = [
            str(item["msg"]) if isinstance(item, dict) and "msg" in item else str(item)
            for item in detail
        ]
    else:
        messages = [str(detail)]
    return "Errors:\n" + "\n".join(messages)


def _load_image(url, timeout=60):
    """Return the image at *url* as a PIL ``Image``.

    ``data:`` URIs are decoded locally because ``requests`` has no transport
    adapter for that scheme. Everything else is downloaded over HTTP(S).

    Raises:
        requests.HTTPError: the download returned a 4xx/5xx status.
        ValueError: a ``data:`` URI is not base64-encoded.
    """
    if url.startswith("data:"):
        header, _, payload = url.partition(",")
        if not header.endswith(";base64"):
            raise ValueError(f"Unsupported data URI encoding: {header}")
        raw = base64.b64decode(payload)
    else:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        raw = response.content
    return Image.open(BytesIO(raw))


def generate_image(api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance):
    """Generate images with FLUX1.1 [pro] on fal.ai and build Gradio updates.

    Args:
        api_key: fal.ai API key entered in the UI. It must be non-blank.
        prompt: Text prompt for the model.
        image_size: fal image-size preset name, e.g. "landscape_4_3".
        seed: Optional seed. None or a blank/whitespace string means the seed
            is omitted from the request.
        sync_mode: Forwarded as ``sync_mode`` when not None. In sync mode the
            returned image URLs are presumably inline ``data:`` URIs, which
            ``_load_image`` handles.
        num_images: Number of images to request. It is coerced to int because
            a Gradio slider may deliver a float.
        enable_safety_checker: Forwarded as ``enable_safety_checker``.
        safety_tolerance: Forwarded as a string when not None.

    Returns:
        A two-element list of ``gr.update`` objects: first the gallery (a list
        of PIL images), then the response textbox (the ``str`` of the raw
        result, or an error message). On any failure the gallery is cleared
        and the error text is shown instead of raising.
    """
    try:
        key = (api_key or "").strip()
        if not key:
            error_msg = "Error: an API key is required. Get one at https://fal.ai/dashboard/keys"
            print(error_msg)
            return [gr.update(value=[]), gr.update(value=error_msg)]

        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            "num_images": int(num_images),
            "enable_safety_checker": enable_safety_checker,
        }

        if safety_tolerance is not None:
            arguments["safety_tolerance"] = str(safety_tolerance)

        # A whitespace-only seed is treated the same as an empty one.
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text:
            arguments["seed"] = int(seed_text)

        if sync_mode is not None:
            arguments["sync_mode"] = sync_mode

        print(f"Request Body: {arguments}")

        # A per-request client keeps this caller's key local to the call.
        # Writing it into os.environ would change it for every session
        # served by this process.
        client = fal_client.SyncClient(key=key)
        handler = client.submit(
            "fal-ai/flux-pro/v1.1",
            arguments=arguments,
        )
        result = handler.get()

        print(f"Response: {result}")

        images = [_load_image(img_info['url']) for img_info in result['images']]

        return [gr.update(value=images, visible=True), gr.update(value=str(result), visible=True)]

    except FalClientError as e:
        error_msg = _format_fal_error(e)
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]

    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return [gr.update(value=[]), gr.update(value=error_msg)]
|
|
|
def update_safety_tolerance_visibility(enable_safety):
    """Show the safety-tolerance dropdown only while the safety checker is on.

    Disabling the checker hides the dropdown and sets it to "6", the most
    permissive level. Re-enabling it restores the default "2".

    The full "1"-"6" choice list is re-sent with every update, so the
    selected value is always one of the dropdown's choices. Gradio may
    otherwise reject the value when the form is submitted.

    Args:
        enable_safety: Current value of the "Enable Safety Checker" checkbox.

    Returns:
        A ``gr.update`` for the safety-tolerance dropdown.
    """
    return gr.update(
        visible=enable_safety,
        choices=["1", "2", "3", "4", "5", "6"],
        value="2" if enable_safety else "6",
    )
|
|
|
# UI layout: inputs for the fal.ai FLUX1.1 [pro] request, a gallery for the
# returned images, and a textbox for the raw response or error text.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
    gr.Markdown("Get your API key at https://fal.ai/dashboard/keys")

    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key here")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")

    with gr.Row():
        image_size = gr.Dropdown(
            label="Image Size",
            choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
            value="landscape_4_3",
        )
        num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)

    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="Enter a number for reproducible results")
        sync_mode = gr.Checkbox(label="Sync Mode", value=False)

    with gr.Row():
        enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
        # "6" must be a valid choice: the note below documents it as the most
        # permissive level, and the checkbox handler selects it when the
        # checker is disabled.
        safety_tolerance = gr.Dropdown(
            label="Safety Tolerance",
            choices=["1", "2", "3", "4", "5", "6"],
            value="2",
            visible=True,
        )

    gr.Markdown("**Note:** Safety Tolerance: 1 is the most strict, 6 is the most permissive. Default is 2.")

    generate_btn = gr.Button("Generate Image")
    output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    response_output = gr.Textbox(label="Response", visible=True)

    # Toggling the safety checker shows/hides (and resets) the tolerance dropdown.
    enable_safety_checker.change(
        fn=update_safety_tolerance_visibility,
        inputs=[enable_safety_checker],
        outputs=[safety_tolerance],
    )

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance],
        outputs=[output_gallery, response_output],
    )


if __name__ == "__main__":
    demo.launch()