# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
import spaces
import os
import logging
import shlex
import time
from dataclasses import dataclass
from typing import Optional
import gradio as gr
import simple_parsing
import yaml
from einops import rearrange, repeat
import numpy as np
import torch
from huggingface_hub import snapshot_download
from pathlib import Path
from transformers import T5ForConditionalGeneration
from torchvision.utils import make_grid
from ml_mdm import helpers, reader
from ml_mdm.config import get_arguments, get_model, get_pipeline
from ml_mdm.language_models import factory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download destination
models = Path("models")
logging.basicConfig(
    level=logging.INFO,
format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
datefmt="%H:%M:%S",
)
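# Fetch the flan-t5-xl language model and the three MDM checkpoints used by the demo.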
def download_all_models():
# Cache language model in the standard location
_ = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
# Download the vision models we use in the demo
snapshot_download("pcuenq/mdm-flickr-64", local_dir=models/"mdm-flickr-64")
snapshot_download("pcuenq/mdm-flickr-256", local_dir=models/"mdm-flickr-256")
snapshot_download("pcuenq/mdm-flickr-1024", local_dir=models/"mdm-flickr-1024")
def dividable(n):
for i in range(int(np.sqrt(n)), 0, -1):
if n % i == 0:
break
return i, n // i
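# Run the language model on the tokenized prompts and store the embeddings and attention mask in the sample dict.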
def generate_lm_outputs(device, sample, tokenizer, language_model, args):
with torch.no_grad():
lm_outputs, lm_mask = language_model(sample, tokenizer)
sample["lm_outputs"] = lm_outputs
sample["lm_mask"] = lm_mask
return sample
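# Build the tokenizer, language model, denoising UNet, and diffusion pipeline on the target device,
# wiring the language-model embedding size into the UNet's conditioning config.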
def setup_models(args, device):
input_channels = 3
# load the language model
tokenizer, language_model = factory.create_lm(args, device=device)
language_model_dim = language_model.embed_dim
args.unet_config.conditioning_feature_dim = language_model_dim
denoising_model = get_model(args.model)(
input_channels, input_channels, args.unet_config
).to(device)
diffusion_model = get_pipeline(args.model)(
denoising_model, args.diffusion_config
).to(device)
# denoising_model.print_size(args.sample_image_size)
return tokenizer, language_model, diffusion_model
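# Plot the sampled log-SNR schedule over normalized timesteps and return the figure as a numpy image.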
def plot_logsnr(logsnrs, total_steps):
import matplotlib.pyplot as plt
x = 1 - np.arange(len(logsnrs)) / (total_steps - 1)
plt.plot(x, np.asarray(logsnrs))
plt.xlabel("timesteps")
plt.ylabel("LogSNR")
plt.grid(True)
plt.xlim(0, 1)
plt.ylim(-20, 10)
plt.gca().invert_xaxis()
# Convert the plot to a numpy array
fig = plt.gcf()
fig.canvas.draw()
    # Read the rendered canvas through the public Agg API instead of the private renderer attribute.
    image = np.asarray(fig.canvas.buffer_rgba())
plt.close()
return image
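# Module-level cache so models and parsed args persist across Gradio calls
# and are only reloaded when the selected checkpoint or overrides change.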
@dataclass
class GLOBAL_DATA:
    reader_config: Optional[reader.ReaderConfig] = None
    tokenizer: Optional[object] = None
    args: Optional[object] = None
    language_model: Optional[object] = None
    diffusion_model: Optional[object] = None
    override_args: str = ""
    ckpt_name: str = ""
global_config = GLOBAL_DATA()
def stop_run():
return (
gr.update(value="Run", variant="primary", visible=True),
gr.update(visible=False),
)
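# Read the model type from a checkpoint's config.yaml ("model", falling back to "vision_model", then "unet").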
def get_model_type(config_file):
with open(config_file, "r") as f:
d = yaml.safe_load(f)
return d.get("model", d.get("vision_model", "unet"))
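# spaces.GPU asks Hugging Face Spaces (ZeroGPU) to attach a GPU for the duration of each call.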
@spaces.GPU
def generate(
ckpt_name="mdm-flickr-64",
prompt="a chair",
input_template="",
negative_prompt="",
negative_template="",
batch_size=20,
guidance_scale=7.5,
threshold_function="clip",
num_inference_steps=250,
eta=0,
save_diffusion_path=False,
show_diffusion_path=False,
show_xt=False,
reader_config="",
seed=10,
comment="",
override_args="",
output_inner=False,
):
np.random.seed(seed)
torch.random.manual_seed(seed)
if len(input_template) > 0:
prompt = input_template.format(prompt=prompt)
if len(negative_template) > 0:
negative_prompt = negative_prompt + negative_template
print(f"Postive: {prompt} / Negative: {negative_prompt}")
vision_model_file = models/ckpt_name/"vis_model.pth"
if not os.path.exists(vision_model_file):
logging.info(f"Did not generate because {vision_model_file} does not exist")
return None, None, f"{vision_model_file} does not exist", None, None
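    # (Re)load models only when the selected checkpoint or the override arguments change;
    # otherwise reuse the instances cached in global_config.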
if (
global_config.ckpt_name != ckpt_name
or global_config.override_args != override_args
):
# Identify model type
model_type = get_model_type(models/ckpt_name/"config.yaml")
# reload the arguments
args = get_arguments(
shlex.split(override_args + f" --model {model_type}"),
mode="demo",
additional_config_paths=[models/ckpt_name/"config.yaml"],
)
helpers.print_args(args)
# setup model when the parent task changed.
args.vocab_file = str(models/ckpt_name/args.vocab_file)
tokenizer, language_model, diffusion_model = setup_models(args, device)
try:
other_items = diffusion_model.model.load(vision_model_file)
except Exception as e:
logging.error(f"failed to load {vision_model_file}", exc_info=e)
return None, None, "Loading Model Error", None, None
# setup global configs
global_config.batch_num = -1 # reset batch num
global_config.args = args
global_config.override_args = override_args
global_config.tokenizer = tokenizer
global_config.language_model = language_model
global_config.diffusion_model = diffusion_model
global_config.reader_config = args.reader_config
global_config.ckpt_name = ckpt_name
else:
args = global_config.args
tokenizer = global_config.tokenizer
language_model = global_config.language_model
diffusion_model = global_config.diffusion_model
tokenizer = global_config.tokenizer
sample = {}
sample["text"] = [negative_prompt, prompt] if guidance_scale != 1 else [prompt]
sample["tokens"] = np.asarray(
reader.process_text(sample["text"], tokenizer, args.reader_config)
)
sample = generate_lm_outputs(device, sample, tokenizer, language_model, args)
assert args.sample_image_size != -1
# set up thresholding
from samplers import ThresholdType
diffusion_model.sampler._config.threshold_function = {
"clip": ThresholdType.CLIP,
"dynamic (Imagen)": ThresholdType.DYNAMIC,
"dynamic (DeepFloyd)": ThresholdType.DYNAMIC_IF,
"none": ThresholdType.NONE,
}[threshold_function]
output_comments = f"{comment}\n"
bsz = batch_size
with torch.no_grad():
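        # Repeat the prompt embeddings (and mask) once per image so the whole batch shares the same conditioning.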
if bsz > 1:
sample["lm_outputs"] = repeat(
sample["lm_outputs"], "b n d -> (b r) n d", r=bsz
)
sample["lm_mask"] = repeat(sample["lm_mask"], "b n -> (b r) n", r=bsz)
num_samples = bsz
original, outputs, logsnrs = [], [], []
logging.info(f"Starting to sample from the model")
start_time = time.time()
for step, result in enumerate(
diffusion_model.sample(
num_samples,
sample,
args.sample_image_size,
device,
return_sequence=False,
num_inference_steps=num_inference_steps,
ddim_eta=eta,
guidance_scale=guidance_scale,
resample_steps=True,
disable_bar=False,
yield_output=True,
yield_full=True,
output_inner=output_inner,
)
):
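            # Each sampler step yields the denoised estimate x0, the current noisy sample x_t,
            # and extra diagnostics (used below for the log-SNR plot).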
x0, x_t, extra = result
if step < num_inference_steps:
g = extra[0][0, 0, 0, 0].cpu()
logsnrs += [torch.log(g / (1 - g))]
output = x0 if not show_xt else x_t
output = torch.clamp(output * 0.5 + 0.5, min=0, max=1).cpu()
original += [
output if not output_inner else output[..., -args.sample_image_size :]
]
output = (
make_grid(output, nrow=dividable(bsz)[0]).permute(1, 2, 0).numpy() * 255
).astype(np.uint8)
outputs += [output]
output_video_path = None
if step == num_inference_steps and save_diffusion_path:
import imageio
writer = imageio.get_writer("temp_output.mp4", fps=32)
for output in outputs:
writer.append_data(output)
writer.close()
output_video_path = "temp_output.mp4"
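                # For checkpoints with temporal layers, re-split the final output grid into frames
                # and write those as the diffusion-path video instead.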
if any(diffusion_model.model.vision_model.is_temporal):
data = rearrange(
original[-1],
"(a b) c (n h) (m w) -> (n m) (a h) (b w) c",
a=dividable(bsz)[0],
n=4,
m=4,
)
data = (data.numpy() * 255).astype(np.uint8)
writer = imageio.get_writer("temp_output.mp4", fps=4)
for d in data:
writer.append_data(d)
writer.close()
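            # Stream intermediate results (when enabled) and the final image to the UI;
            # the last yield re-enables Run and hides Stop.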
if show_diffusion_path or (step == num_inference_steps):
yield output, plot_logsnr(
logsnrs, num_inference_steps
), output_comments + f"Step ({step} / {num_inference_steps}) Time ({time.time() - start_time:.4}s)", output_video_path, gr.update(
value="Run",
variant="primary",
visible=(step == num_inference_steps),
), gr.update(
value="Stop", variant="stop", visible=(step != num_inference_steps)
)
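# Build the Gradio UI and wire the controls to generate().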
def main():
download_all_models()
    # Load the example prompts shown below the input box
    example_texts = Path("data/prompts_demo.tsv").read_text().splitlines()
css = """
#config-accordion, #logs-accordion {color: black !important;}
.dark #config-accordion, .dark #logs-accordion {color: white !important;}
.stop {background: darkred !important;}
"""
with gr.Blocks(
title="Demo of Text-to-Image Diffusion",
theme="EveryPizza/Cartoony-Gradio-Theme",
css=css,
) as demo:
with gr.Row(equal_height=True):
header = """
# Matryoshka MDM – Text-to-Image Diffusion Model Web Demo by MLR
This is a demo of model `mdm-flickr-64`. For additional models, please check [our repo](https://github.com/apple/ml-mdm).
### Usage
- Select one of the examples below or enter your own prompt
- Adjust advanced settings such as the number of inference steps.
"""
gr.Markdown(header)
with gr.Row(equal_height=False):
pid = gr.State()
ckpt_name = gr.Label("mdm-flickr-64", visible=False)
# with gr.Row():
prompt_input = gr.Textbox(label="Input prompt")
with gr.Row(equal_height=False):
with gr.Column(scale=1):
with gr.Row(equal_height=False):
guidance_scale = gr.Slider(
value=7.5,
minimum=0.0,
maximum=50,
step=0.1,
label="Guidance scale",
)
with gr.Row(equal_height=False):
batch_size = gr.Slider(
value=4, minimum=1, maximum=32, step=1, label="Number of images"
)
with gr.Column(scale=1):
with gr.Row(equal_height=False):
with gr.Column(scale=1):
save_diffusion_path = gr.Checkbox(
value=True, label="Show diffusion path as a video"
)
show_diffusion_path = gr.Checkbox(
value=False, label="Show diffusion progress"
)
show_xt = gr.Checkbox(value=False, label="Show predicted x_t")
with gr.Column(scale=1):
output_inner = gr.Checkbox(
value=False,
label="Output inner UNet (High-res models Only)",
visible=False,
)
with gr.Row(equal_height=False):
comment = gr.Textbox(value="", label="Comments to the model (optional)")
with gr.Row(equal_height=False):
with gr.Column(scale=2):
output_image = gr.Image(value=None, label="Output image")
with gr.Column(scale=2):
output_video = gr.Video(value=None, label="Diffusion Path")
with gr.Row(equal_height=False):
with gr.Column(scale=2):
with gr.Accordion(
"Advanced settings", open=False, elem_id="config-accordion"
):
input_template = gr.Dropdown(
[
"",
"breathtaking {prompt}. award-winning, professional, highly detailed",
"anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
"concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
"ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"cinematic film still {prompt}. shallow depth of field, vignette, highly detailed, high budget Hollywood movie, bokeh, cinemascope, moody",
"analog film photo {prompt}. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage",
"vaporwave synthwave style {prompt}. cyberpunk, neon, vibes, stunningly beautiful, crisp, detailed, sleek, ultramodern, high contrast, cinematic composition",
"isometric style {prompt}. vibrant, beautiful, crisp, detailed, ultra detailed, intricate",
"low-poly style {prompt}. ambient occlusion, low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
"claymation style {prompt}. sculpture, clay art, centered composition, play-doh",
"professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
"origami style {prompt}. paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
"pixel-art {prompt}. low-res, blocky, pixel art style, 16-bit graphics",
],
value="",
label="Positive Template (by default, not use)",
)
with gr.Row(equal_height=False):
with gr.Column(scale=1):
negative_prompt_input = gr.Textbox(
value="", label="Negative prompt"
)
with gr.Column(scale=1):
negative_template = gr.Dropdown(
[
"",
"anime, cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry",
"photo, deformed, black and white, realism, disfigured, low contrast",
"photo, photorealistic, realism, ugly",
"photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
"drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
"anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
"painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
"illustration, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
"deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy, realistic, photographic",
"noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo",
"ugly, deformed, noisy, low poly, blurry, painting",
],
value="",
label="Negative Template (by default, not use)",
)
with gr.Row(equal_height=False):
with gr.Column(scale=1):
threshold_function = gr.Dropdown(
[
"clip",
"dynamic (Imagen)",
"dynamic (DeepFloyd)",
"none",
],
value="dynamic (DeepFloyd)",
label="Thresholding",
)
with gr.Column(scale=1):
reader_config = gr.Dropdown(
["configs/datasets/reader_config.yaml"],
value="configs/datasets/reader_config.yaml",
label="Reader Config",
)
with gr.Row(equal_height=False):
with gr.Column(scale=1):
num_inference_steps = gr.Slider(
value=50,
minimum=1,
maximum=2000,
step=1,
label="# of steps",
)
with gr.Column(scale=1):
eta = gr.Slider(
value=0,
minimum=0,
maximum=1,
step=0.05,
label="DDIM eta",
)
seed = gr.Slider(
value=137,
minimum=0,
maximum=2147483647,
step=1,
label="Random seed",
)
override_args = gr.Textbox(
value="--reader_config.max_token_length 128 --reader_config.max_caption_length 512",
label="Override model arguments (optional)",
)
run_btn = gr.Button(value="Run", variant="primary")
stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
with gr.Column(scale=2):
with gr.Accordion(
"Addditional outputs", open=False, elem_id="output-accordion"
):
with gr.Row(equal_height=True):
output_text = gr.Textbox(value=None, label="System output")
with gr.Row(equal_height=True):
logsnr_fig = gr.Image(value=None, label="Noise schedule")
run_event = run_btn.click(
fn=generate,
inputs=[
ckpt_name,
prompt_input,
input_template,
negative_prompt_input,
negative_template,
batch_size,
guidance_scale,
threshold_function,
num_inference_steps,
eta,
save_diffusion_path,
show_diffusion_path,
show_xt,
reader_config,
seed,
comment,
override_args,
output_inner,
],
outputs=[
output_image,
logsnr_fig,
output_text,
output_video,
run_btn,
stop_btn,
],
)
stop_btn.click(
fn=stop_run,
outputs=[run_btn, stop_btn],
cancels=[run_event],
queue=False,
)
# example0 = gr.Examples(
# [
# ["mdm-flickr-64", 64, 50, 0],
# ["mdm-flickr-256", 16, 100, 0],
# ["mdm-flickr-1024", 4, 250, 1],
# ],
# inputs=[ckpt_name, batch_size, num_inference_steps, eta],
# )
example1 = gr.Examples(
examples=[[t.strip()] for t in example_texts],
inputs=[prompt_input],
)
launch_args = {}
demo.queue(default_concurrency_limit=1).launch(**launch_args)
if __name__ == "__main__":
main()