File size: 5,583 Bytes
9887d4c
 
 
099c99b
9887d4c
ef1c0b9
96c3ec5
 
9887d4c
 
 
93fd450
af079bb
93fd450
 
 
 
af079bb
b3b1ca1
b39d4ca
7acfd95
 
9887d4c
93fd450
 
af079bb
93fd450
 
af079bb
9887d4c
 
 
96c3ec5
 
 
 
 
ef1c0b9
96c3ec5
9887d4c
 
 
 
 
 
96c3ec5
 
 
 
 
 
 
 
 
 
9887d4c
 
cc280c7
aece66e
96c3ec5
 
 
 
aece66e
9887d4c
 
96c3ec5
9887d4c
 
 
96c3ec5
9887d4c
 
 
 
 
 
 
 
 
 
 
 
621bbdc
5769334
25f07f0
9887d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
96c3ec5
9887d4c
 
 
 
 
96c3ec5
 
9887d4c
96c3ec5
9887d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
099c99b
9887d4c
 
 
 
 
 
 
099c99b
9887d4c
 
 
 
 
 
 
 
 
099c99b
9887d4c
 
 
 
 
099c99b
9887d4c
621bbdc
9887d4c
 
 
 
5ddbee5
4dd28e3
944abe8
96c3ec5
9887d4c
 
f8ac431
 
9887d4c
 
5072f90
9887d4c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
import numpy as np
import random
from diffusers import AuraFlowPipeline
import torch
import spaces
import uuid
import os

# Module-level setup: device detection, model loading and shared limits.
# Everything here runs once, at import time.

# NOTE(review): `device` is computed here but is not referenced anywhere in the
# visible code; the pipeline below is moved to "cuda" unconditionally.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Disabled performance-tuning knobs (kept for reference).
#torch.set_float32_matmul_precision("high")

#torch._inductor.config.conv_1x1_as_mm = True
#torch._inductor.config.coordinate_descent_tuning = True
#torch._inductor.config.epilogue_fusion = False
#torch._inductor.config.coordinate_descent_check_all_directions = True

# Load the "fal/AuraFlow" pipeline with float16 weights and move it to CUDA.
pipe = AuraFlowPipeline.from_pretrained(
	"fal/AuraFlow",
    torch_dtype=torch.float16
).to("cuda")

# Disabled channels-last / torch.compile experiments (kept for reference).
#pipe.transformer.to(memory_format=torch.channels_last)
#pipe.vae.to(memory_format=torch.channels_last)

#pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
#pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Largest signed 32-bit integer; used as the upper bound for random seeds in
# `infer` and as the maximum of the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (in pixels) of the width and height sliders.
MAX_IMAGE_SIZE = 1024

def save_image(img):
    """Save *img* under a freshly generated PNG file name and return the name.

    The name is a random UUID4 followed by ".png". It is a bare relative
    name, so the file is created relative to the current working directory.
    *img* only needs to provide a ``save(path)`` method.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename

@spaces.GPU
def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=30, progress=gr.Progress(track_tqdm=True)):
    """Run the AuraFlow pipeline for *prompt* and save the resulting images.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when *randomize_seed*
            is true.
        randomize_seed: When true, replace *seed* with a random integer in
            ``[0, MAX_SEED]``.
        width: Forwarded to the pipeline as ``width``.
        height: Forwarded to the pipeline as ``height``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        progress: Gradio progress object; not referenced in the body
            (presumably consumed by Gradio to mirror tqdm progress).

    Returns:
        A tuple ``(image_paths, seed)``: the file names written by
        ``save_image`` for every generated image, and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Generator is created without a device argument and seeded explicitly so
    # the same seed can be reported back to the caller.
    rng = torch.Generator().manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
    )

    saved_paths = list(map(save_image, output.images))
    return saved_paths, seed

# Example prompts offered in the UI; each entry is passed as the `prompt`
# input of `gr.Examples` below.
examples = [
    "A photo of a lavender cat",
    "Astronaut in a jungle grasping a sign board contain word 'I love SPACE', cold color palette, muted colors, detailed, futuristic",
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic",
    "a cute robot artist painting on an easel, concept art",
    "An alien grasping a sign board contain word 'AuraFlow', futuristic, neonpunk, detailed",
    "Kids going to school, sketch"
]


# Custom CSS handed to `gr.Blocks`: centers the element with id
# "col-container" and caps its width at 600px.
css="""
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""

# Label naming the kind of device torch reports: "GPU" when CUDA is
# available, otherwise "CPU".
power_device = "GPU" if torch.cuda.is_available() else "CPU"

# Build the Gradio UI. Components are laid out in the order they are created
# inside the nested context managers, so statement order here is significant.
with gr.Blocks(css=css) as demo:
    
    # Single column; its elem_id matches the "#col-container" rule in `css`.
    with gr.Column(elem_id="col-container"):
        # Header text. NOTE: this is an f-string without any placeholders.
        gr.Markdown(f"""
        # AuraFlow 0.1
        Demo of the [AuraFlow 0.1](https://huggingface.co/fal/AuraFlow) 6.8B parameters open source diffusion transformer model
        [[blog](https://blog.fal.ai/auraflow/)] [[model](https://huggingface.co/fal/AuraFlow)] [[fal](https://fal.ai/models/fal-ai/aura-flow)]
        """)
        
        # Prompt box and run button side by side.
        with gr.Row():
            
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            
            run_button = gr.Button("Run", scale=0)
        
        # Gallery receiving the image file names returned by `infer`.
        result = gr.Gallery(label="Result", columns=1, show_label=False)

        # Optional generation parameters, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=5,
                lines=4,
                placeholder="Enter a negative prompt",
                value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
            )
            
            # Seed slider; also used as an output so it shows the seed that
            # `infer` actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            # Checked by default, so the slider value is replaced by a random
            # seed inside `infer` unless the user unticks it.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            # Output dimensions: 256..MAX_IMAGE_SIZE in steps of 32.
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            
            # Sampling controls.
            with gr.Row():
                
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                
                # NOTE(review): the UI default is 28 steps, while `infer`'s own
                # default for `num_inference_steps` is 30.
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        
        # Example prompts. Only `prompt` is passed as input, so every other
        # `infer` parameter takes the function's own default (empty negative
        # prompt, seed 42, no randomization, 30 steps) rather than the values
        # of the widgets above.
        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [prompt],
            outputs = [result, seed],
            cache_examples=True
        )

    # Run `infer` when the button is clicked or when either text box is
    # submitted; results go to the gallery and the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit, negative_prompt.submit],
        fn = infer,
        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs = [result, seed]
    )

# Enable the request queue and start the app. This runs at import time; there
# is no `if __name__ == "__main__"` guard.
demo.queue().launch()