import os

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from loguru import logger

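# Models known to work with this Space; point MODEL_ID at one of these.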
SUPPORTED_MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
]
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"


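# Runtime configuration, overridable via environment variables.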
model = os.environ.get("MODEL_ID", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))


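# Load in half precision to cut memory use. This assumes the checkpoint
# repo ships an "fp16" safetensors variant.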
def load_pipeline(model):
    return AutoPipelineForText2Image.from_pretrained(
        model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )


logger.debug(f"Loading pipeline: {dict(model=model)}")
pipe = load_pipeline(model).to("cuda")


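# @spaces.GPU requests a ZeroGPU device for up to `gpu_duration` seconds per
# call; @logger.catch logs any exception before re-raising it to Gradio.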
@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)
def infer(
    prompt: str,
    negative_prompt: str | None,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    logger.info(f"Starting image generation: {dict(model=model, prompt=prompt)}")

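    # Drop zero/empty values so the pipeline falls back to its own
    # model-specific defaults for steps and guidance.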
    additional_args = {
        k: v
        for k, v in dict(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).items()
        if v
    }

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images
    return images[0]


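# UI: a prompt box with a Run button, the generated image, and advanced
# settings tucked into a collapsed accordion.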
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Text-to-Image")
        gr.Markdown(f"## Model: `{model}`")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

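            # Both sliders start at 0, which infer() treats as "unset" and
            # filters out, so each model keeps its own defaults.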
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )

                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    value=0.0,
                )

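    # Run inference on button click or on Enter in the prompt box.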
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()