File size: 4,644 Bytes
eead421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58cbf34
 
 
 
 
 
eead421
58cbf34
 
 
 
 
 
eead421
58cbf34
 
 
 
 
 
 
 
 
987d247
58cbf34
987d247
 
 
58cbf34
 
eead421
 
58cbf34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eead421
58cbf34
 
eead421
58cbf34
eead421
58cbf34
 
 
 
 
 
 
eead421
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import io
import logging
import os

import gradio as gr
import requests
from fal_client import submit
from PIL import Image

def set_fal_key(api_key):
    """Export *api_key* as the ``FAL_KEY`` environment variable.

    Returns a short confirmation message once the variable has been written.
    """
    os.environ.update({"FAL_KEY": api_key})
    confirmation = "FAL API key set successfully!"
    return confirmation

def generate_image(api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
    """Generate images with a FAL-hosted FLUX model and return them as PIL images.

    Args:
        api_key: FAL API key; exported as ``FAL_KEY`` before the request.
        model: "Flux Pro", "Flux Dev", or anything else (treated as Flux Schnell).
        prompt: Text prompt forwarded to the model.
        image_size: Size preset forwarded to the model.
        num_inference_steps: Number of inference steps forwarded to the model.
        guidance_scale: Forwarded for Flux Pro and Flux Dev only.
        num_images: Number of images requested.
        safety_tolerance: Forwarded for Flux Pro only.
        enable_safety_checker: Forwarded for Flux Dev and Flux Schnell only.
        seed: Seed for the request; ``-1`` (or ``None``) omits it from the request.

    Returns:
        A list of PIL images downloaded from the URLs in the result. If the
        request or any download fails, the error is logged and a list holding
        a single 512x512 black placeholder image is returned instead.
    """
    # Upper bound, in seconds, for each image download so a stalled
    # connection cannot hang the request forever.
    download_timeout = 60

    # NOTE(review): the key is written to the process-wide environment, so
    # concurrent callers overwrite each other's key - confirm this is acceptable.
    set_fal_key(api_key)

    arguments = {
        "prompt": prompt,
        "image_size": image_size,
        "num_inference_steps": num_inference_steps,
        "num_images": num_images,
    }

    # Each model accepts a slightly different set of optional arguments.
    if model == "Flux Pro":
        arguments["guidance_scale"] = guidance_scale
        arguments["safety_tolerance"] = safety_tolerance
        fal_model = "fal-ai/flux-pro"
    elif model == "Flux Dev":
        arguments["guidance_scale"] = guidance_scale
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/dev"
    else:  # Flux Schnell
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/schnell"

    # The UI's numeric field presumably delivers a float (or None when the
    # field is cleared); send a proper integer and skip missing values.
    if seed is not None and seed != -1:
        arguments["seed"] = int(seed)

    try:
        handler = submit(fal_model, arguments=arguments)
        result = handler.get()
        images = []
        for img_info in result["images"]:
            img_response = requests.get(img_info["url"], timeout=download_timeout)
            # Fail on HTTP errors instead of trying to decode an error page as an image.
            img_response.raise_for_status()
            images.append(Image.open(io.BytesIO(img_response.content)))
        return images
    except Exception:
        # Best-effort fallback: keep the UI responsive with a placeholder,
        # but record the failure so it can be diagnosed.
        logging.getLogger(__name__).exception("Image generation with %s failed", fal_model)
        return [Image.new('RGB', (512, 512), color='black')]

def update_visible_components(model):
    """Build the component updates that match the selected model.

    Returns a list of four ``gr.update`` objects, in the order: inference
    steps, guidance scale, safety tolerance, safety checker. Controls that a
    model does not use are hidden; the visible ones are reset to a default.
    """
    hidden = {"visible": False}

    def shown(value):
        # Keyword arguments for a visible control reset to *value*.
        return {"visible": True, "value": value}

    if model == "Flux Pro":
        settings = (shown(28), shown(3.5), shown("2"), hidden)
    elif model == "Flux Dev":
        settings = (shown(28), shown(3.5), hidden, shown(True))
    else:  # Flux Schnell
        settings = (shown(4), hidden, hidden, shown(True))

    return [gr.update(**kwargs) for kwargs in settings]

# Build the Gradio UI: a header, an input column (left) and a gallery (right).
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    # Page title and reference links.
    gr.HTML("""
    <h1 align="center">FLUX.1 Image Generation</h1>
    <p align="center">
    <a href="https://blackforestlabs.ai/" target="_blank">[Black Forest Labs]</a>
    <a href="https://blackforestlabs.ai/announcing-black-forest-labs/" target="_blank">[Blog]</a>
    <a href="https://fal.ai/models/fal-ai/flux-pro" target="_blank">[FLUX.1 [pro] Model FAL]</a>
    </p>
    """)

    with gr.Row():
        # Left column: credentials, model choice, prompt and generation settings.
        with gr.Column(scale=1):
            api_key = gr.Textbox(type="password", label="FAL API Key")
            model = gr.Dropdown(
                label="Model",
                choices=["Flux Pro", "Flux Dev", "Flux Schnell"],
                value="Flux Pro"
            )
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Add your prompt here")
            image_size = gr.Dropdown(
                choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
                label="Image Size",
                value="landscape_4_3"
            )
            
            # Less frequently changed options, collapsed by default.
            with gr.Accordion("Advanced settings", open=False):
                num_inference_steps = gr.Slider(1, 100, 28, step=1, label="Number of Inference Steps")
                guidance_scale = gr.Slider(0, 20, 3.5, step=0.1, label="Guidance Scale")
                num_images = gr.Slider(1, 10, 1, step=1, label="Number of Images")
                safety_tolerance = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6"], label="Safety Tolerance", value="2")
                enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
                # -1 is the sentinel generate_image uses to leave the seed out of the request.
                seed = gr.Number(label="Seed", value=-1)

            generate_btn = gr.Button("Generate Image")

        # Right column: gallery showing the generated images.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    # Show/hide and reset the model-specific controls whenever the model changes.
    # The outputs order must match the list returned by update_visible_components.
    model.change(update_visible_components, inputs=[model], outputs=[num_inference_steps, guidance_scale, safety_tolerance, enable_safety_checker])

    # Run generation on click; the inputs order must match generate_image's parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            api_key, model, prompt, image_size, num_inference_steps,
            guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed
        ],
        outputs=[output_gallery]
    )

# Start the app. There is no __main__ guard, so this runs whenever the module is executed or imported.
demo.launch()