#!/usr/bin/env python
# coding: utf-8
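"""Gradio demo for latent diffusion text-to-image generation (CompVis/ldm-text2im-large-256).

Generates a grid of six images for a prompt. Selecting one of the results reuses its seed,
so variations of the prompt can be explored from the same starting state.
"""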
import random
from collections import defaultdict
from functools import partial
from itertools import zip_longest

import gradio as gr
import torch
from diffusers import DiffusionPipeline

SELECT_LABEL = "Select as seed"

MODEL_ID = "CompVis/ldm-text2im-large-256"
STEPS = 5   # keep the step count low while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6
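# ETA controls the stochasticity of the DDIM sampler and GUIDANCE_SCALE the strength of
# classifier-free guidance; both are passed straight through to the pipeline call in infer_grid.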

ldm = DiffusionPipeline.from_pretrained(MODEL_ID)
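# If a CUDA device is available, the pipeline can be moved to it with `ldm.to("cuda")`;
# as written, inference runs on the CPU.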

print(f"cuda: {torch.cuda.is_available()}")

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    state = gr.Variable({
        'selected': -1,
        'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)]
    })
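    # `state` is per-session: 'selected' is the index of the grid image chosen as a seed
    # (-1 means no selection) and 'seeds' holds the RNG seed used for each of the six images.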

    def infer_seeded_image(prompt, seed):
        print(f"Prompt: {prompt}, seed: {seed}")
        images, _ = infer_grid(prompt, n=1, seeds=[seed])
        return images[0]

    def infer_grid(prompt, n=6, seeds=()):
        # Generate images one at a time instead of in a single batch: the pipeline gives us
        # no way to recover the per-image seeds from a batched call, and we need them for reuse.
        result = defaultdict(list)
        for _, seed in zip_longest(range(n), seeds, fillvalue=None):
            seed = random.randint(0, 2**32 - 1) if seed is None else seed
            _ = torch.manual_seed(seed)
            # autocast("cuda") is a no-op without a CUDA device, so enable it only when one is present
            with torch.autocast("cuda", enabled=torch.cuda.is_available()):
                images = ldm(
                    [prompt],
                    num_inference_steps=STEPS,
                    eta=ETA,
                    guidance_scale=GUIDANCE_SCALE
                )["sample"]
            result["images"].append(images[0])
            result["seeds"].append(seed)
        return result["images"], result["seeds"]
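    # Hypothetical usage: `imgs, seeds = infer_grid("a watercolor fox", n=2)` would return two
    # generated images together with the RNG seeds that produced them, so a seed can be replayed.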

    def infer(prompt, state):
        """
        Outputs:
        - Grid images (list)
        - Seeded Image (Image or None)
        - Grid Box with updated visibility
        - Seeded Box with updated visibility
        """
        grid_images = [None] * 6
        image_with_seed = None
        visible = (False, False)

        if (seed_index := state["selected"]) > -1:
            seed = state["seeds"][seed_index]
            image_with_seed = infer_seeded_image(prompt, seed)
            visible = (False, True)
        else:
            grid_images, seeds = infer_grid(prompt)
            state["seeds"] = seeds
            visible = (True, False)

        boxes = [gr.Box.update(visible=v) for v in visible]
        return grid_images + [image_with_seed] + boxes + [state]

    def update_state(selected_index: int, value, state):
        """Keep the six radio groups mutually exclusive and remember which seed was selected."""
        if value == '':
            # Selection was cleared: leave the other radios as they are.
            others_value = gr.components._Keywords.NO_VALUE
        else:
            # An image was selected: clear the other radios and store the selected index.
            others_value = ''
            state["selected"] = selected_index
        return [gr.Radio.update(value=others_value) for _ in range(5)] + [state]

    def clear_seed(state):
        """Update state of Radio buttons, grid, seeded_box"""
        state["selected"] = -1
        return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
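    # The return order of clear_seed matches `clear_seed_button.click(outputs=...)` below:
    # six radio values, then the grid box, the seeded box, and the state.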

    def image_block():
        return gr.Image(
            interactive=False, show_label=False
        ).style(
            # border = (True, True, False, True),
            rounded = (True, True, False, False),
        )

    def radio_block():
        radio = gr.Radio(
            choices=[SELECT_LABEL], interactive=True, show_label=False,
        ).style(
            # border = (False, True, True, True),
            # rounded = (False, False, True, True)
            container=False
        )
        return radio

    gr.Markdown(
        """
        <h1><center>Latent Diffusion Demo</center></h1>
        <p>Type anything to generate a few images that represent your prompt.
        Select one of the results to use as a <b>seed</b> for the next generation:
        you can try variations of your prompt starting from the same state and see how it changes.
        For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
        <i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
        If your prompts are similar, the tweaked result should also have a similar structure
        but different details or style.</p>
        """
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    # margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        ## Can we create a Component with these, so it can participate as an output?
        with (grid := gr.Box()):
            with gr.Row():
                with gr.Box().style(border=None):
                    image1 = image_block()
                    select1 = radio_block()
                with gr.Box().style(border=None):
                    image2 = image_block()
                    select2 = radio_block()
                with gr.Box().style(border=None):
                    image3 = image_block()
                    select3 = radio_block()
            with gr.Row():
                with gr.Box().style(border=None):
                    image4 = image_block()
                    select4 = radio_block()
                with gr.Box().style(border=None):
                    image5 = image_block()
                    select5 = radio_block()
                with gr.Box().style(border=None):
                    image6 = image_block()
                    select6 = radio_block()

        images = [image1, image2, image3, image4, image5, image6]
        selectors = [select1, select2, select3, select4, select5, select6]

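        # Wire each radio so that selecting it clears the other five and records its index
        # in the shared state (see update_state above).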
        for i, radio in enumerate(selectors):
            others = list(filter(lambda s: s != radio, selectors))
            radio.change(
                partial(update_state, i),
                inputs=[radio, state],
                outputs=others + [state],
                queue=False
            )

    with (seeded_box := gr.Box(visible=False)):
        seeded_image = image_block()
        clear_seed_button = gr.Button("Return to Grid")
    clear_seed_button.click(
        clear_seed,
        inputs=[state],
        outputs=selectors + [grid, seeded_box] + [state]
    )

    all_images = images + [seeded_image]
    boxes = [grid, seeded_box]
    infer_outputs = all_images + boxes + [state]

    text.submit(
        infer,
        inputs=[text, state],
        outputs=infer_outputs
    )
    btn.click(
        infer,
        inputs=[text, state],
        outputs=infer_outputs
    )

demo.launch(enable_queue=True)