gokaygokay commited on
Commit
eead421
1 Parent(s): 9432380

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ from fal_client import submit
7
+
8
+ def set_fal_key(api_key):
9
+ os.environ["FAL_KEY"] = api_key
10
+ return "FAL API key set successfully!"
11
+
12
+ def generate_image(api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
13
+ set_fal_key(api_key)
14
+
15
+ arguments = {
16
+ "prompt": prompt,
17
+ "image_size": image_size,
18
+ "num_inference_steps": num_inference_steps,
19
+ "num_images": num_images,
20
+ }
21
+
22
+ if model == "Flux Pro":
23
+ arguments["guidance_scale"] = guidance_scale
24
+ arguments["safety_tolerance"] = safety_tolerance
25
+ fal_model = "fal-ai/flux-pro"
26
+ elif model == "Flux Dev":
27
+ arguments["guidance_scale"] = guidance_scale
28
+ arguments["enable_safety_checker"] = enable_safety_checker
29
+ fal_model = "fal-ai/flux/dev"
30
+ else: # Flux Schnell
31
+ arguments["enable_safety_checker"] = enable_safety_checker
32
+ fal_model = "fal-ai/flux/schnell"
33
+
34
+ if seed != -1:
35
+ arguments["seed"] = seed
36
+
37
+ try:
38
+ handler = submit(fal_model, arguments=arguments)
39
+ result = handler.get()
40
+ images = []
41
+ for img_info in result["images"]:
42
+ img_url = img_info["url"]
43
+ img_response = requests.get(img_url)
44
+ img = Image.open(io.BytesIO(img_response.content))
45
+ images.append(img)
46
+ return images
47
+ except Exception as e:
48
+ return [Image.new('RGB', (512, 512), color='black')]
49
+
50
+ def update_visible_components(model):
51
+ if model == "Flux Pro":
52
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
53
+ elif model == "Flux Dev":
54
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)]
55
+ else: # Flux Schnell
56
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
57
+
58
+ with gr.Blocks() as demo:
59
+ gr.Markdown("# Flux Image Generation")
60
+
61
+ api_key = gr.Textbox(type="password", label="FAL API Key")
62
+
63
+ with gr.Row():
64
+ model = gr.Dropdown(choices=["Flux Pro", "Flux Dev", "Flux Schnell"], label="Model", value="Flux Pro")
65
+ image_size = gr.Dropdown(choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"], label="Image Size", value="landscape_4_3")
66
+
67
+ prompt = gr.Textbox(label="Prompt", lines=3)
68
+ num_inference_steps = gr.Slider(1, 100, value=28, step=1, label="Number of Inference Steps")
69
+ guidance_scale = gr.Slider(0, 20, value=3.5, step=0.1, label="Guidance Scale")
70
+ num_images = gr.Slider(1, 10, value=1, step=1, label="Number of Images")
71
+ safety_tolerance = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6"], label="Safety Tolerance", value="2")
72
+ enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True)
73
+ seed = gr.Number(label="Seed", value=-1)
74
+
75
+ generate_button = gr.Button("Generate Images")
76
+ output_images = gr.Gallery(label="Generated Images")
77
+
78
+ model.change(update_visible_components, inputs=[model], outputs=[guidance_scale, safety_tolerance, enable_safety_checker])
79
+
80
+ generate_button.click(
81
+ generate_image,
82
+ inputs=[api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed],
83
+ outputs=output_images
84
+ )
85
+
86
+ demo.launch()