#!/usr/bin/env python
# coding: utf-8
import queue
import gradio as gr
import random
import torch
from collections import defaultdict
from diffusers import DiffusionPipeline
from functools import partial
from itertools import zip_longest
from typing import List
from PIL import Image
# UI label and model configuration. Inference settings are tuned for CPU:
# very few diffusion steps keep latency tolerable at the cost of quality.
SELECT_LABEL = "Select as seed"
MODEL_ID = "CompVis/ldm-text2im-large-256"
STEPS = 5  # while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6

# Downloads/loads the pretrained latent-diffusion pipeline on first run.
# Fixed: removed a duplicate mid-file `import torch` (torch is already
# imported at the top of the file).
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)
print(f"cuda: {torch.cuda.is_available()}")
# Root layout. `state` is per-session: `selected` is the index of the grid
# image chosen as seed source (-1 = none), `seeds` are the RNG seeds used
# for the six images of the most recent grid.
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    state = gr.Variable({
        'selected': -1,
        'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)]
    })
def infer_seeded_image(prompt, seed):
print(f"Prompt: {prompt}, seed: {seed}")
images, _ = infer_grid(prompt, n=1, seeds=[seed])
return images[0]
def infer_grid(prompt, n=6, seeds=[]):
# Unfortunately we have to iterate instead of requesting all images at once,
# because we have no way to get the intermediate generation seeds.
result = defaultdict(list)
for _, seed in zip_longest(range(n), seeds, fillvalue=None):
seed = random.randint(0, 2**32 - 1) if seed is None else seed
_ = torch.manual_seed(seed)
with torch.autocast("cuda"):
images = ldm(
[prompt],
num_inference_steps=STEPS,
eta=ETA,
guidance_scale=GUIDANCE_SCALE
)["sample"]
result["images"].append(images[0])
result["seeds"].append(seed)
return result["images"], result["seeds"]
def infer(prompt, state):
"""
Outputs:
- Grid images (list)
- Seeded Image (Image or None)
- Grid Box with updated visibility
- Seeded Box with updated visibility
"""
grid_images = [None] * 6
image_with_seed = None
visible = (False, False)
if (seed_index := state["selected"]) > -1:
seed = state["seeds"][seed_index]
image_with_seed = infer_seeded_image(prompt, seed)
visible = (False, True)
else:
grid_images, seeds = infer_grid(prompt)
state["seeds"] = seeds
visible = (True, False)
boxes = [gr.Box.update(visible=v) for v in visible]
return grid_images + [image_with_seed] + boxes + [state]
def update_state(selected_index: int, value, state):
if value == '':
others_value = gr.components._Keywords.NO_VALUE
else:
others_value = ''
state["selected"] = selected_index
return [gr.Radio.update(value=others_value) for _ in range(5)] + [state]
def clear_seed(state):
"""Update state of Radio buttons, grid, seeded_box"""
state["selected"] = -1
return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
def image_block():
return gr.Image(
interactive=False, show_label=False
).style(
# border = (True, True, False, True),
rounded = (True, True, False, False),
)
def radio_block():
radio = gr.Radio(
choices=[SELECT_LABEL], interactive=True, show_label=False,
).style(
# border = (False, True, True, True),
# rounded = (False, False, True, True)
container=False
)
return radio
gr.Markdown(
"""
<h1><center>Latent Diffusion Demo</center></h1>
<p>Type anything to generate a few images that represent your prompt.
Select one of the results to use as a <b>seed</b> for the next generation:
you can try variations of your prompt starting from the same state and see how it changes.
For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
<i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
If your prompts are similar, the tweaked result should also have a similar structure
but different details or style.</p>
"""
)
    # Prompt input: textbox and Run button styled to look like one joined
    # control (rounded only on the outer corners).
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    # margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
    ## Can we create a Component with these, so it can participate as an output?
    # 2x3 grid of generated images, each with its own "Select as seed" radio.
    # `grid` (the whole Box) is toggled visible/hidden by `infer`/`clear_seed`.
    with (grid := gr.Box()):
        with gr.Row():
            with gr.Box().style(border=None):
                image1 = image_block()
                select1 = radio_block()
            with gr.Box().style(border=None):
                image2 = image_block()
                select2 = radio_block()
            with gr.Box().style(border=None):
                image3 = image_block()
                select3 = radio_block()
        with gr.Row():
            with gr.Box().style(border=None):
                image4 = image_block()
                select4 = radio_block()
            with gr.Box().style(border=None):
                image5 = image_block()
                select5 = radio_block()
            with gr.Box().style(border=None):
                image6 = image_block()
                select6 = radio_block()
    # Index order here must match the seed indices stored in `state`.
    images = [image1, image2, image3, image4, image5, image6]
    selectors = [select1, select2, select3, select4, select5, select6]
for i, radio in enumerate(selectors):
others = list(filter(lambda s: s != radio, selectors))
radio.change(
partial(update_state, i),
inputs=[radio, state],
outputs=others + [state],
queue=False
)
    # Single-image view shown after a seed is selected; hidden on first load.
    with (seeded_box := gr.Box()):
        seeded_image = image_block()
        clear_seed_button = gr.Button("Return to Grid")
    seeded_box.visible = False
    # "Return to Grid" clears the selection and swaps box visibility back.
    clear_seed_button.click(
        clear_seed,
        inputs=[state],
        outputs=selectors + [grid, seeded_box] + [state]
    )
all_images = images + [seeded_image]
boxes = [grid, seeded_box]
infer_outputs = all_images + boxes + [state]
text.submit(
infer,
inputs=[text, state],
outputs=infer_outputs
)
btn.click(
infer,
inputs=[text, state],
outputs=infer_outputs
)
demo.launch(enable_queue=True) |