# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
import spaces
import os
import logging
import shlex
import time
from dataclasses import dataclass
from typing import Optional

import gradio as gr
import simple_parsing
import yaml
from einops import rearrange, repeat

import numpy as np
import torch
from huggingface_hub import snapshot_download
from pathlib import Path
from transformers import T5ForConditionalGeneration
from torchvision.utils import make_grid

from ml_mdm import helpers, reader
from ml_mdm.config import get_arguments, get_model, get_pipeline
from ml_mdm.language_models import factory

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download destination
models = Path("models")

logging.basicConfig(
    level=getattr(logging, "INFO", None),
    format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
    datefmt="%H:%M:%S",
)


def download_all_models():
    # Cache language model in the standard location
    _ = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
    # Download the vision models we use in the demo
    snapshot_download("pcuenq/mdm-flickr-64", local_dir=models/"mdm-flickr-64")
    snapshot_download("pcuenq/mdm-flickr-256", local_dir=models/"mdm-flickr-256")
    snapshot_download("pcuenq/mdm-flickr-1024", local_dir=models/"mdm-flickr-1024")


def dividable(n):
    # Return the most square-like (rows, cols) factorization of n, used for image grids
    for i in range(int(np.sqrt(n)), 0, -1):
        if n % i == 0:
            break
    return i, n // i


def generate_lm_outputs(device, sample, tokenizer, language_model, args):
    with torch.no_grad():
        lm_outputs, lm_mask = language_model(sample, tokenizer)
        sample["lm_outputs"] = lm_outputs
        sample["lm_mask"] = lm_mask
    return sample


def setup_models(args, device):
    input_channels = 3

    # load the language model
    tokenizer, language_model = factory.create_lm(args, device=device)
    language_model_dim = language_model.embed_dim
    args.unet_config.conditioning_feature_dim = language_model_dim
    denoising_model = get_model(args.model)(
        input_channels, input_channels, args.unet_config
    ).to(device)
    diffusion_model = get_pipeline(args.model)(
        denoising_model, args.diffusion_config
    ).to(device)
    # denoising_model.print_size(args.sample_image_size)
    return tokenizer, language_model, diffusion_model


def plot_logsnr(logsnrs, total_steps):
    import matplotlib.pyplot as plt

    x = 1 - np.arange(len(logsnrs)) / (total_steps - 1)
    plt.plot(x, np.asarray(logsnrs))
    plt.xlabel("timesteps")
    plt.ylabel("LogSNR")
    plt.grid(True)
    plt.xlim(0, 1)
    plt.ylim(-20, 10)
    plt.gca().invert_xaxis()
    # Convert the plot to a numpy array
    fig = plt.gcf()
    fig.canvas.draw()
    image = np.array(fig.canvas.renderer._renderer)
    plt.close()
    return image


@dataclass
class GLOBAL_DATA:
    reader_config: Optional[reader.ReaderConfig] = None
    tokenizer = None
    args = None
    language_model = None
    diffusion_model = None
    override_args = ""
    ckpt_name = ""


global_config = GLOBAL_DATA()


def stop_run():
    return (
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    )


def get_model_type(config_file):
    with open(config_file, "r") as f:
        d = yaml.safe_load(f)
        return d.get("model", d.get("vision_model", "unet"))

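
# `generate` (below) is a Python generator: it yields intermediate results so the
# Gradio UI can stream diffusion progress, and the `@spaces.GPU` decorator requests a
# GPU when the demo runs on Hugging Face Spaces. It reads each checkpoint from
# models/<ckpt_name>/ (vis_model.pth, config.yaml and the vocab file), which the
# snapshots fetched by `download_all_models()` are assumed to provide.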
@spaces.GPU
def generate(
    ckpt_name="mdm-flickr-64",
    prompt="a chair",
    input_template="",
    negative_prompt="",
    negative_template="",
    batch_size=20,
    guidance_scale=7.5,
    threshold_function="clip",
    num_inference_steps=250,
    eta=0,
    save_diffusion_path=False,
    show_diffusion_path=False,
    show_xt=False,
    reader_config="",
    seed=10,
    comment="",
    override_args="",
    output_inner=False,
):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    if len(input_template) > 0:
        prompt = input_template.format(prompt=prompt)
    if len(negative_template) > 0:
        negative_prompt = negative_prompt + negative_template
    print(f"Positive: {prompt} / Negative: {negative_prompt}")

    vision_model_file = models/ckpt_name/"vis_model.pth"
    if not os.path.exists(vision_model_file):
        logging.info(f"Did not generate because {vision_model_file} does not exist")
        return None, None, f"{vision_model_file} does not exist", None, None

    if (
        global_config.ckpt_name != ckpt_name
        or global_config.override_args != override_args
    ):
        # Identify model type
        model_type = get_model_type(models/ckpt_name/"config.yaml")
        # reload the arguments
        args = get_arguments(
            shlex.split(override_args + f" --model {model_type}"),
            mode="demo",
            additional_config_paths=[models/ckpt_name/"config.yaml"],
        )
        helpers.print_args(args)

        # setup model when the parent task changed.
        args.vocab_file = str(models/ckpt_name/args.vocab_file)
        tokenizer, language_model, diffusion_model = setup_models(args, device)
        try:
            other_items = diffusion_model.model.load(vision_model_file)
        except Exception as e:
            logging.error(f"failed to load {vision_model_file}", exc_info=e)
            return None, None, "Loading Model Error", None, None

        # setup global configs
        global_config.batch_num = -1  # reset batch num
        global_config.args = args
        global_config.override_args = override_args
        global_config.tokenizer = tokenizer
        global_config.language_model = language_model
        global_config.diffusion_model = diffusion_model
        global_config.reader_config = args.reader_config
        global_config.ckpt_name = ckpt_name
    else:
        args = global_config.args
        tokenizer = global_config.tokenizer
        language_model = global_config.language_model
        diffusion_model = global_config.diffusion_model

    tokenizer = global_config.tokenizer
    sample = {}
    sample["text"] = [negative_prompt, prompt] if guidance_scale != 1 else [prompt]
    sample["tokens"] = np.asarray(
        reader.process_text(sample["text"], tokenizer, args.reader_config)
    )
    sample = generate_lm_outputs(device, sample, tokenizer, language_model, args)
    assert args.sample_image_size != -1

    # set up thresholding
    from ml_mdm.samplers import ThresholdType

    diffusion_model.sampler._config.threshold_function = {
        "clip": ThresholdType.CLIP,
        "dynamic (Imagen)": ThresholdType.DYNAMIC,
        "dynamic (DeepFloyd)": ThresholdType.DYNAMIC_IF,
        "none": ThresholdType.NONE,
    }[threshold_function]

    output_comments = f"{comment}\n"
    bsz = batch_size
    with torch.no_grad():
        if bsz > 1:
            sample["lm_outputs"] = repeat(
                sample["lm_outputs"], "b n d -> (b r) n d", r=bsz
            )
            sample["lm_mask"] = repeat(sample["lm_mask"], "b n -> (b r) n", r=bsz)

        num_samples = bsz
        original, outputs, logsnrs = [], [], []
        logging.info("Starting to sample from the model")
        start_time = time.time()
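        # Sampling loop: the pipeline yields one (x0, x_t, extra) tuple per inference
        # step. x0 is the current denoised estimate and x_t the noisy sample; extra[0]
        # carries the per-step gamma, converted to LogSNR as log(g / (1 - g)) for the
        # noise-schedule plot. The final yield (step == num_inference_steps) carries
        # the finished image grid and, optionally, the diffusion-path video.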
        for step, result in enumerate(
            diffusion_model.sample(
                num_samples,
                sample,
                args.sample_image_size,
                device,
                return_sequence=False,
                num_inference_steps=num_inference_steps,
                ddim_eta=eta,
                guidance_scale=guidance_scale,
                resample_steps=True,
                disable_bar=False,
                yield_output=True,
                yield_full=True,
                output_inner=output_inner,
            )
        ):
            x0, x_t, extra = result
            if step < num_inference_steps:
                g = extra[0][0, 0, 0, 0].cpu()
                logsnrs += [torch.log(g / (1 - g))]

            output = x0 if not show_xt else x_t
            output = torch.clamp(output * 0.5 + 0.5, min=0, max=1).cpu()
            original += [
                output if not output_inner else output[..., -args.sample_image_size :]
            ]
            output = (
                make_grid(output, nrow=dividable(bsz)[0]).permute(1, 2, 0).numpy() * 255
            ).astype(np.uint8)
            outputs += [output]

            output_video_path = None
            if step == num_inference_steps and save_diffusion_path:
                import imageio

                writer = imageio.get_writer("temp_output.mp4", fps=32)
                for output in outputs:
                    writer.append_data(output)
                writer.close()
                output_video_path = "temp_output.mp4"

                if any(diffusion_model.model.vision_model.is_temporal):
                    data = rearrange(
                        original[-1],
                        "(a b) c (n h) (m w) -> (n m) (a h) (b w) c",
                        a=dividable(bsz)[0],
                        n=4,
                        m=4,
                    )
                    data = (data.numpy() * 255).astype(np.uint8)
                    writer = imageio.get_writer("temp_output.mp4", fps=4)
                    for d in data:
                        writer.append_data(d)
                    writer.close()

            if show_diffusion_path or (step == num_inference_steps):
                yield output, plot_logsnr(
                    logsnrs, num_inference_steps
                ), output_comments + f"Step ({step} / {num_inference_steps}) Time ({time.time() - start_time:.4}s)", output_video_path, gr.update(
                    value="Run",
                    variant="primary",
                    visible=(step == num_inference_steps),
                ), gr.update(
                    value="Stop", variant="stop", visible=(step != num_inference_steps)
                )

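# A minimal sketch (illustrative, not part of the original demo) of driving `generate`
# outside the Gradio UI. `generate` is a generator, so exhausting it and keeping the
# last yielded tuple gives the finished image grid; this assumes the checkpoints have
# already been downloaded and a GPU is available:
#
#     download_all_models()
#     *_, last = generate(prompt="a chair", batch_size=4, num_inference_steps=50)
#     image_grid, logsnr_plot, status, video_path, _, _ = last
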
def main():
    download_all_models()

    # load the example prompts (one prompt per line)
    example_texts = open("data/prompts_demo.tsv").readlines()

    css = """
    #config-accordion, #logs-accordion {color: black !important;}
    .dark #config-accordion, .dark #logs-accordion {color: white !important;}
    .stop {background: darkred !important;}
    """

    with gr.Blocks(
        title="Demo of Text-to-Image Diffusion",
        theme="EveryPizza/Cartoony-Gradio-Theme",
        css=css,
    ) as demo:
        with gr.Row(equal_height=True):
            header = """
# Matryoshka MDM – Text-to-Image Diffusion Model Web Demo by MLR

This is a demo of model `mdm-flickr-64`. For additional models, please check
[our repo](https://github.com/apple/ml-mdm).

### Usage
- Select an example below or manually input a prompt
- Change more advanced settings such as the number of inference steps.
"""
            gr.Markdown(header)

        with gr.Row(equal_height=False):
            pid = gr.State()
            ckpt_name = gr.Label("mdm-flickr-64", visible=False)
            # with gr.Row():
            prompt_input = gr.Textbox(label="Input prompt")

        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    guidance_scale = gr.Slider(
                        value=7.5,
                        minimum=0.0,
                        maximum=50,
                        step=0.1,
                        label="Guidance scale",
                    )
                with gr.Row(equal_height=False):
                    batch_size = gr.Slider(
                        value=4, minimum=1, maximum=32, step=1, label="Number of images"
                    )
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Column(scale=1):
                        save_diffusion_path = gr.Checkbox(
                            value=True, label="Show diffusion path as a video"
                        )
                        show_diffusion_path = gr.Checkbox(
                            value=False, label="Show diffusion progress"
                        )
                        show_xt = gr.Checkbox(value=False, label="Show predicted x_t")
                    with gr.Column(scale=1):
                        output_inner = gr.Checkbox(
                            value=False,
                            label="Output inner UNet (High-res models Only)",
                            visible=False,
                        )

        with gr.Row(equal_height=False):
            comment = gr.Textbox(value="", label="Comments to the model (optional)")

        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                output_image = gr.Image(value=None, label="Output image")
            with gr.Column(scale=2):
                output_video = gr.Video(value=None, label="Diffusion Path")

        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                with gr.Accordion(
                    "Advanced settings", open=False, elem_id="config-accordion"
                ):
                    input_template = gr.Dropdown(
                        [
                            "",
                            "breathtaking {prompt}. award-winning, professional, highly detailed",
                            "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
                            "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
                            "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
                            "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
                            "cinematic film still {prompt}. shallow depth of field, vignette, highly detailed, high budget Hollywood movie, bokeh, cinemascope, moody",
                            "analog film photo {prompt}. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage",
                            "vaporwave synthwave style {prompt}. cyberpunk, neon, vibes, stunningly beautiful, crisp, detailed, sleek, ultramodern, high contrast, cinematic composition",
                            "isometric style {prompt}. vibrant, beautiful, crisp, detailed, ultra detailed, intricate",
                            "low-poly style {prompt}. ambient occlusion, low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
                            "claymation style {prompt}. sculpture, clay art, centered composition, play-doh",
                            "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
                            "origami style {prompt}. paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
                            "pixel-art {prompt}. low-res, blocky, pixel art style, 16-bit graphics",
                        ],
                        value="",
                        label="Positive Template (not used by default)",
                    )
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=1):
                            negative_prompt_input = gr.Textbox(
                                value="", label="Negative prompt"
                            )
                        with gr.Column(scale=1):
                            negative_template = gr.Dropdown(
                                [
                                    "",
                                    "anime, cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry",
                                    "photo, deformed, black and white, realism, disfigured, low contrast",
                                    "photo, photorealistic, realism, ugly",
                                    "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
                                    "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
                                    "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
                                    "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
                                    "illustration, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
                                    "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy, realistic, photographic",
                                    "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo",
                                    "ugly, deformed, noisy, low poly, blurry, painting",
                                ],
                                value="",
                                label="Negative Template (not used by default)",
                            )
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=1):
                            threshold_function = gr.Dropdown(
                                [
                                    "clip",
                                    "dynamic (Imagen)",
                                    "dynamic (DeepFloyd)",
                                    "none",
                                ],
                                value="dynamic (DeepFloyd)",
                                label="Thresholding",
                            )
                        with gr.Column(scale=1):
                            reader_config = gr.Dropdown(
                                ["configs/datasets/reader_config.yaml"],
                                value="configs/datasets/reader_config.yaml",
                                label="Reader Config",
                            )
                    with gr.Row(equal_height=False):
                        with gr.Column(scale=1):
                            num_inference_steps = gr.Slider(
                                value=50,
                                minimum=1,
                                maximum=2000,
                                step=1,
                                label="# of steps",
                            )
                        with gr.Column(scale=1):
                            eta = gr.Slider(
                                value=0,
                                minimum=0,
                                maximum=1,
                                step=0.05,
                                label="DDIM eta",
                            )
                    seed = gr.Slider(
                        value=137,
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        label="Random seed",
                    )
                    override_args = gr.Textbox(
                        value="--reader_config.max_token_length 128 --reader_config.max_caption_length 512",
                        label="Override model arguments (optional)",
                    )
                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

            with gr.Column(scale=2):
                with gr.Accordion(
outputs", open=False, elem_id="output-accordion" ): with gr.Row(equal_height=True): output_text = gr.Textbox(value=None, label="System output") with gr.Row(equal_height=True): logsnr_fig = gr.Image(value=None, label="Noise schedule") run_event = run_btn.click( fn=generate, inputs=[ ckpt_name, prompt_input, input_template, negative_prompt_input, negative_template, batch_size, guidance_scale, threshold_function, num_inference_steps, eta, save_diffusion_path, show_diffusion_path, show_xt, reader_config, seed, comment, override_args, output_inner, ], outputs=[ output_image, logsnr_fig, output_text, output_video, run_btn, stop_btn, ], ) stop_btn.click( fn=stop_run, outputs=[run_btn, stop_btn], cancels=[run_event], queue=False, ) # example0 = gr.Examples( # [ # ["mdm-flickr-64", 64, 50, 0], # ["mdm-flickr-256", 16, 100, 0], # ["mdm-flickr-1024", 4, 250, 1], # ], # inputs=[ckpt_name, batch_size, num_inference_steps, eta], # ) example1 = gr.Examples( examples=[[t.strip()] for t in example_texts], inputs=[prompt_input], ) launch_args = {} demo.queue(default_concurrency_limit=1).launch(**launch_args) if __name__ == "__main__": main()