import base64
import os
import traceback
from io import BytesIO

import fal_client
import gradio as gr
import requests
from fal_client.client import FalClientError
from PIL import Image

MODEL_ID = "fal-ai/flux-pro/v1.1"

# Tolerance levels offered in the UI: "1" is the most strict, "6" the most
# permissive. "6" must be a valid choice because the visibility toggle sets it.
SAFETY_TOLERANCE_CHOICES = ["1", "2", "3", "4", "5", "6"]
DEFAULT_SAFETY_TOLERANCE = "2"
MOST_PERMISSIVE_SAFETY_TOLERANCE = "6"

# Seconds to wait when downloading a generated image from its URL.
IMAGE_DOWNLOAD_TIMEOUT = 60


def _error_outputs(error_msg):
    """Log ``error_msg`` and return UI updates that clear the gallery and show it."""
    print(error_msg)
    return [gr.update(value=[]), gr.update(value=error_msg)]


def _format_fal_error(exc):
    """Build a readable message from a ``FalClientError``.

    The first exception argument may be a list of ``{"msg": ...}`` dicts,
    a single such dict, or a plain string. Each shape is rendered without
    raising, so the error handler itself can never crash.
    """
    detail = exc.args[0] if exc.args else str(exc)
    if isinstance(detail, dict):
        detail = [detail]
    if isinstance(detail, list):
        messages = [
            item.get("msg", str(item)) if isinstance(item, dict) else str(item)
            for item in detail
        ]
    else:
        messages = [str(detail)]
    return "Errors:\n" + "\n".join(messages)


def _load_image(url):
    """Return a PIL image for ``url``, an HTTP(S) URL or a base64 ``data:`` URI.

    Raises:
        ValueError: if ``url`` is a data URI that is not base64-encoded.
        requests.HTTPError: if the download returns an error status.
    """
    if url.startswith("data:"):
        # requests cannot fetch data: URIs. With sync_mode the API may embed
        # the image inline this way instead of returning a CDN URL.
        header, _, payload = url.partition(",")
        if not header.endswith(";base64"):
            raise ValueError(f"Unsupported data URI header: {header!r}")
        return Image.open(BytesIO(base64.b64decode(payload)))
    response = requests.get(url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
    # Surface a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def generate_image(api_key, prompt, image_size, seed, sync_mode, num_images,
                   enable_safety_checker, safety_tolerance):
    """Generate images with FLUX1.1 [pro] on fal and format them for the UI.

    Args:
        api_key: fal API key; exported as ``FAL_KEY`` for ``fal_client``.
        prompt: Text prompt.
        image_size: One of the preset size names offered by the dropdown.
        seed: Optional seed as text. Blank or whitespace means "no seed".
        sync_mode: Forwarded to the API as ``sync_mode`` when not None.
        num_images: Number of images to request (cast to ``int``).
        enable_safety_checker: Forwarded as ``enable_safety_checker``.
        safety_tolerance: Tolerance level string, "1" (strict) to "6" (permissive).

    Returns:
        A two-element list of ``gr.update`` objects for the gallery and the
        response textbox. On success the gallery holds the images and the
        textbox holds the raw result. On failure the gallery is cleared and
        the textbox shows the error message.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        return _error_outputs("Error: An API key is required.")

    seed_text = "" if seed is None else str(seed).strip()
    try:
        seed_value = int(seed_text) if seed_text else None
    except ValueError:
        return _error_outputs(f"Error: Seed must be an integer, got {seed_text!r}.")

    try:
        # NOTE(review): os.environ is process-wide, so concurrent sessions
        # would share whichever key was set last.
        os.environ["FAL_KEY"] = api_key

        arguments = {
            "prompt": prompt,
            "image_size": image_size,
            # Cast in case the slider delivers a float such as 1.0.
            "num_images": int(num_images),
            "enable_safety_checker": enable_safety_checker,
            "safety_tolerance": safety_tolerance,
        }
        if seed_value is not None:
            arguments["seed"] = seed_value
        if sync_mode is not None:
            arguments["sync_mode"] = sync_mode

        # Log the actual request body (the API key is not part of it).
        print(f"Request Body: {arguments}")

        handler = fal_client.submit(MODEL_ID, arguments=arguments)
        result = handler.get()

        # Display and log the response
        print(f"Response: {result}")

        images = [_load_image(img_info["url"]) for img_info in result["images"]]
        return [
            gr.update(value=images, visible=True),
            gr.update(value=str(result), visible=True),
        ]
    except FalClientError as e:
        return _error_outputs(_format_fal_error(e))
    except Exception as e:
        return _error_outputs(
            f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        )


def update_safety_tolerance_visibility(enable_safety):
    """Show the tolerance dropdown only while the safety checker is enabled.

    Enabling the checker restores the default tolerance ("2"). Disabling it
    hides the dropdown and selects the most permissive level ("6").
    """
    if enable_safety:
        return gr.update(visible=True, value=DEFAULT_SAFETY_TOLERANCE)
    return gr.update(visible=False, value=MOST_PERMISSIVE_SAFETY_TOLERANCE)


with gr.Blocks() as demo:
    gr.Markdown("# FLUX1.1 [pro] Text-to-Image Generator")
    gr.Markdown("Get your API key at https://fal.ai/dashboard/keys")

    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key here")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")

    with gr.Row():
        image_size = gr.Dropdown(
            label="Image Size",
            choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
            value="landscape_4_3",
        )
        num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)

    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="Enter a number for reproducible results")
        sync_mode = gr.Checkbox(label="Sync Mode", value=False)

    with gr.Row():
        enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
        safety_tolerance = gr.Dropdown(
            label="Safety Tolerance",
            choices=SAFETY_TOLERANCE_CHOICES,
            value=DEFAULT_SAFETY_TOLERANCE,
            visible=True,
        )

    gr.Markdown("**Note:** Safety Tolerance: 1 is the most strict, 6 is the most permissive. Default is 2.")

    generate_btn = gr.Button("Generate Image")
    output_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2)
    response_output = gr.Textbox(label="Response", visible=True)

    enable_safety_checker.change(
        fn=update_safety_tolerance_visibility,
        inputs=[enable_safety_checker],
        outputs=[safety_tolerance],
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[api_key, prompt, image_size, seed, sync_mode, num_images, enable_safety_checker, safety_tolerance],
        outputs=[output_gallery, response_output],
    )

if __name__ == "__main__":
    demo.launch()