freddyaboulton's picture
Fix infinite
34b7830
#!/usr/bin/env python
# coding: utf-8
import queue
import gradio as gr
import random
import torch
from collections import defaultdict
from diffusers import DiffusionPipeline
from functools import partial
from itertools import zip_longest
from typing import List
from PIL import Image
# Label of the single choice offered by each image's radio selector.
SELECT_LABEL = "Select as seed"
# Model identifier passed to `DiffusionPipeline.from_pretrained` below.
MODEL_ID = "CompVis/ldm-text2im-large-256"
# Generation parameters forwarded to every `ldm(...)` call in `infer_grid`
# (as `num_inference_steps`, `eta` and `guidance_scale` respectively).
STEPS = 5 # while running on CPU
ETA = 0.3
GUIDANCE_SCALE = 6
# The pipeline is loaded once, at import time, and used by `infer_grid`.
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)
# NOTE(review): redundant, torch is already imported at the top of the file.
import torch
# Report at startup whether a CUDA device is available.
print(f"cuda: {torch.cuda.is_available()}")
# Build the demo UI. NOTE(review): indentation was lost in this copy of the
# file; the definitions and components that follow presumably belong inside
# this `with` block, up to (not including) the final `demo.launch` call.
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
# State object passed into and returned from the event handlers:
#   'selected': index into 'seeds' of the image chosen as seed; -1 means
#               no image is selected (see `infer` and `clear_seed`).
#   'seeds':    starts as six random 32-bit seeds and is overwritten by
#               `infer` with the seeds of the most recently generated grid.
state = gr.Variable({
'selected': -1,
'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)]
})
def infer_seeded_image(prompt, seed):
    """Generate the one image that `prompt` yields for a fixed `seed`.

    Logs the request, delegates to `infer_grid` with a single-element seed
    list, and returns the first (only) generated image.
    """
    print(f"Prompt: {prompt}, seed: {seed}")
    generated, _ = infer_grid(prompt, seeds=[seed], n=1)
    return generated[0]
def infer_grid(prompt, n=6, seeds=None):
    """Generate images for `prompt` one at a time, recording each seed used.

    Args:
        prompt: Text prompt handed to the `ldm` pipeline.
        n: Number of images to generate when no more than `n` seeds are given.
        seeds: Optional seeds to reuse, in order. Positions without a seed
            (or holding `None`) get a fresh random 32-bit seed. If more than
            `n` seeds are supplied, one image is generated per seed.

    Returns:
        A `(images, seeds)` tuple of two parallel lists: the generated
        images and the seed each one was generated with.
    """
    # `None` sentinel instead of a mutable `[]` default, which would be a
    # single list object shared by every call.
    if seeds is None:
        seeds = []

    images = []
    used_seeds = []
    # Unfortunately we have to iterate instead of requesting all images at once,
    # because we have no way to get the intermediate generation seeds.
    for _, seed in zip_longest(range(n), seeds, fillvalue=None):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # Seed the global torch RNG so this image can be reproduced later.
        torch.manual_seed(seed)
        with torch.autocast("cuda"):
            sample = ldm(
                [prompt],
                num_inference_steps=STEPS,
                eta=ETA,
                guidance_scale=GUIDANCE_SCALE,
            )["sample"]
        images.append(sample[0])
        used_seeds.append(seed)
    return images, used_seeds
def infer(prompt, state):
    """
    Generate either a fresh grid or the image for the currently selected seed.

    Outputs:
    - Grid images (list)
    - Seeded Image (Image or None)
    - Grid Box with updated visibility
    - Seeded Box with updated visibility
    - State (with refreshed seeds after a grid generation)
    """
    selected = state["selected"]
    if selected > -1:
        # A seed is selected: regenerate only that image and show the seeded box.
        grid_images = [None] * 6
        image_with_seed = infer_seeded_image(prompt, state["seeds"][selected])
        show_grid, show_seeded = False, True
    else:
        # No selection: generate a new grid and remember the seeds it used.
        grid_images, new_seeds = infer_grid(prompt)
        state["seeds"] = new_seeds
        image_with_seed = None
        show_grid, show_seeded = True, False

    box_updates = [
        gr.Box.update(visible=show_grid),
        gr.Box.update(visible=show_seeded),
    ]
    return [*grid_images, image_with_seed, *box_updates, state]
def update_state(selected_index: int, value, state):
    """Handle a change of the radio selector at `selected_index`.

    Args:
        selected_index: Position of the radio whose value changed.
        value: The radio's new value; '' when it has been cleared.
        state: State dict whose "selected" entry tracks the chosen image.

    Returns:
        Five `gr.Radio` updates (one per other selector) followed by `state`.
    """
    if value == '':
        # This radio was cleared rather than picked (clearing happens when
        # another radio is selected or when `clear_seed` resets them all).
        # The selection must NOT be recorded here, otherwise these clearing
        # events would overwrite the index of the radio actually chosen.
        # NO_VALUE is presumably gradio's "leave the value as it is"
        # sentinel, so no further change events cascade. TODO confirm.
        untouched = gr.components._Keywords.NO_VALUE
        return [gr.Radio.update(value=untouched) for _ in range(5)] + [state]

    # A real selection: remember it and clear the five other radios.
    state["selected"] = selected_index
    return [gr.Radio.update(value='') for _ in range(5)] + [state]
def clear_seed(state):
    """Update state of Radio buttons, grid, seeded_box"""
    # Forget the selected seed.
    state["selected"] = -1
    cleared_radios = [''] * 6
    # Show the grid box again and hide the seeded box.
    box_updates = [gr.Box.update(visible=True), gr.Box.update(visible=False)]
    return cleared_radios + box_updates + [state]
def image_block():
    """Create a non-interactive, label-less image component for results."""
    image = gr.Image(interactive=False, show_label=False)
    # A `border=(True, True, False, True)` style was left disabled here.
    return image.style(rounded=(True, True, False, False))
def radio_block():
    """Create the interactive, label-less radio offering only SELECT_LABEL."""
    # Border/rounded styling was left disabled; only the container is removed.
    return gr.Radio(
        choices=[SELECT_LABEL],
        interactive=True,
        show_label=False,
    ).style(container=False)
# Intro text: the demo title plus instructions on reusing a result as a seed.
gr.Markdown(
"""
<h1><center>Latent Diffusion Demo</center></h1>
<p>Type anything to generate a few images that represent your prompt.
Select one of the results to use as a <b>seed</b> for the next generation:
you can try variations of your prompt starting from the same state and see how it changes.
For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
<i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
If your prompts are similar, the tweaked result should also have a similar structure
but different details or style.</p>
"""
)
# Prompt input: a single-line textbox and a "Run" button placed in one row.
# NOTE(review): indentation was lost in this copy, so it cannot be told from
# here whether the components created after `btn` also sit inside this Group.
with gr.Group():
with gr.Box():
with gr.Row().style(mobile_collapse=False, equal_height=True):
# `text` is the prompt input wired to `infer` (on submit) further down.
text = gr.Textbox(
label="Enter your prompt", show_label=False, max_lines=1
).style(
border=(True, False, True, True),
# margin=False,
rounded=(True, False, False, True),
container=False,
)
# `btn` triggers the same `infer` handler on click; its `rounded` tuple
# complements the textbox's, presumably so the two look like one control.
btn = gr.Button("Run").style(
margin=False,
rounded=(False, True, True, False),
)
## Can we create a Component with these, so it can participate as an output?
# Results grid: 2 rows x 3 cells, each cell a borderless box holding an image
# and its "select as seed" radio. Components are created in the order
# image1, select1, image2, select2, ... image6, select6.
grid = gr.Box()
images = []
selectors = []
with grid:
    for _row in range(2):
        with gr.Row():
            for _col in range(3):
                with gr.Box().style(border=None):
                    images.append(image_block())
                    selectors.append(radio_block())
# Keep the individual names available alongside the two lists.
image1, image2, image3, image4, image5, image6 = images
select1, select2, select3, select4, select5, select6 = selectors
# Wire each selector so that changing it updates the five other selectors
# and the shared state through `update_state`, bound to the selector's index.
for index, selector in enumerate(selectors):
    siblings = [other for other in selectors if other != selector]
    selector.change(
        partial(update_state, index),
        inputs=[selector, state],
        outputs=siblings + [state],
        queue=False,
    )
# Box holding the single seeded image and the button that returns to the grid.
seeded_box = gr.Box()
with seeded_box:
    seeded_image = image_block()
    clear_seed_button = gr.Button("Return to Grid")
# Hidden until `infer` runs with a selected seed.
seeded_box.visible = False
# "Return to Grid" resets every selector, both boxes' visibility and the state.
clear_seed_button.click(
    clear_seed,
    inputs=[state],
    outputs=[*selectors, grid, seeded_box, state],
)
# Output order must match what `infer` returns: the six grid images, the
# seeded image, the two boxes (grid, seeded), then the state.
all_images = [*images, seeded_image]
boxes = [grid, seeded_box]
infer_outputs = [*all_images, *boxes, state]
# Pressing Enter in the prompt box and clicking "Run" both call `infer`.
for register in (text.submit, btn.click):
    register(
        infer,
        inputs=[text, state],
        outputs=infer_outputs,
    )
# Start serving the demo with request queueing enabled.
demo.launch(enable_queue=True)