#!/usr/bin/env python
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import gc
import os
import random
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import pyrallis
import torch
from gradio.components import Image, Textbox
from torchvision.utils import _log_api_usage_once, make_grid, save_image

warnings.filterwarnings("ignore")  # ignore warnings

from asset.examples import examples
from diffusion import DPMS, FlowEuler, SASolverSampler
from diffusion.data.datasets.utils import *
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig
from diffusion.utils.dist_utils import flush
from tools.download import find_model

# from diffusion.utils.misc import read_config

MAX_SEED = np.iinfo(np.int32).max
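

# Read only --config from the command line; parse_known_args() leaves any remaining
# flags for pyrallis, which parses the full SanaInference configuration in __main__.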
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config path")
    return parser.parse_known_args()[0]
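

# Demo configuration: extends SanaConfig with app-specific options (checkpoint path,
# server port, default CFG/PAG scales); populated from the YAML config via pyrallis.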
@dataclass
class SanaInference(SanaConfig):
    config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml"  # config path
    model_path: str = field(
        default="output/Sana_1600M/SANA.pth", metadata={"help": "Path to the model file (positional)"}
    )
    output: str = "./output"
    bs: int = 1
    image_size: int = 1024
    cfg_scale: float = 5.0
    pag_scale: float = 2.0
    seed: int = 42
    step: int = -1
    port: int = 7788
    custom_image_size: Optional[int] = None
    shield_model_path: str = field(
        default="google/shieldgemma-2b",
        metadata={"help": "The path to the shield model; we employ ShieldGemma-2B by default."},
    )
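

# Convert a (batch of) image tensor(s) into a single uint8 HWC numpy array, mirroring
# the scaling and rounding that torchvision.utils.save_image applies before writing to disk.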
def ndarr_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    **kwargs,
) -> np.ndarray:
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr
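

# Seed PyTorch and disable autograd for inference; the 30 dummy torch.randn calls
# advance the global RNG state before sampling.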
def set_env(seed=0):
    torch.manual_seed(seed)
    torch.set_grad_enabled(False)
    for _ in range(30):
        torch.randn(1, 4, args.image_size, args.image_size)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
    """Returns binned height and width."""
    ar = float(height / width)
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
    default_hw = ratios[closest_ratio]
    return int(default_hw[0]), int(default_hw[1])
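

# End-to-end generation for one prompt: encode the text (optionally prefixed with the
# chi prompt), draw Gaussian latents at the binned aspect ratio, denoise them with the
# selected sampler, decode with the VAE, and resize back to the requested resolution.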
def generate_img(
    prompt,
    sampler,
    sample_steps,
    scale,
    pag_scale=1.0,
    guidance_type="classifier-free",
    seed=0,
    randomize_seed=False,
    base_size=1024,
    height=1024,
    width=1024,
):
    flush()
    gc.collect()
    torch.cuda.empty_cache()

    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)

    base_ratios = eval(f"ASPECT_RATIO_{base_size}_TEST")
    os.makedirs("output/demo/online_demo_prompts/", exist_ok=True)
    save_prompt_path = f"output/demo/online_demo_prompts/tested_prompts{datetime.now().date()}.txt"
    with open(save_prompt_path, "a") as f:
        f.write(f"{seed}: {prompt}" + "\n")
    print(f"{seed}: {prompt}")
    prompt_clean, prompt_show, _, _, _ = prepare_prompt_ar(prompt, base_ratios, device=device)  # ar for aspect ratio
    orig_height, orig_width = height, width
    height, width = classify_height_width_bin(height, width, ratios=base_ratios)
    prompt_show += (
        f"\n Sample steps: {sample_steps}, CFG Scale: {scale}, PAG Scale: {pag_scale}, flow_shift: {flow_shift}"
    )
    prompt_clean = prompt_clean.strip()
    if isinstance(prompt_clean, str):
        prompts = [prompt_clean]

    # prepare text feature
    if not config.text_encoder.chi_prompt:
        max_length_all = max_sequence_length
        prompts_all = prompts
    else:
        chi_prompt = "\n".join(config.text_encoder.chi_prompt)
        prompts_all = [chi_prompt + prompt for prompt in prompts]
        num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
        max_length_all = num_chi_prompt_tokens + max_sequence_length - 2  # magic number 2: [bos], [_]
    caption_token = tokenizer(
        prompts_all, max_length=max_length_all, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)
    select_index = [0] + list(range(-max_sequence_length + 1, 0))
    caption_embs = text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][:, :, select_index]
    emb_masks = caption_token.attention_mask[:, select_index]
    null_y = null_caption_embs.repeat(len(prompts), 1, 1)[:, None]

    n = len(prompts)
    latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate
    z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device, dtype=weight_dtype)
    model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks)
    print(f"Latent Size: {z.shape}")
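
    # Each sampler branch below denoises the random latents `z` conditioned on caption_embs
    # (null_y is the unconditional embedding); the resulting `samples` are decoded by the VAE further down.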
    # Sample images:
    if sampler == "dpm-solver":
        # Create sampling noise:
        dpm_solver = DPMS(
            model.forward_with_dpmsolver,
            condition=caption_embs,
            uncondition=null_y,
            cfg_scale=scale,
            model_kwargs=model_kwargs,
        )
        samples = dpm_solver.sample(
            z,
            steps=sample_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
    elif sampler == "sa-solver":
        # Create sampling noise:
        sa_solver = SASolverSampler(model.forward_with_dpmsolver, device=device)
        samples = sa_solver.sample(
            S=sample_steps,
            batch_size=n,
            shape=(4, latent_size_h, latent_size_w),
            eta=1,
            conditioning=caption_embs,
            unconditional_conditioning=null_y,
            unconditional_guidance_scale=scale,
            model_kwargs=model_kwargs,
        )[0]
    elif sampler == "flow_euler":
        flow_solver = FlowEuler(
            model, condition=caption_embs, uncondition=null_y, cfg_scale=scale, model_kwargs=model_kwargs
        )
        samples = flow_solver.sample(
            z,
            steps=sample_steps,
        )
    elif sampler == "flow_dpm-solver":
        # Fall back to plain classifier-free guidance unless PAG is requested (pag_scale > 1.0)
        # and the model uses linear attention.
        if not (pag_scale > 1.0 and config.model.attn_type == "linear"):
            guidance_type = "classifier-free"
        dpm_solver = DPMS(
            model,
            condition=caption_embs,
            uncondition=null_y,
            guidance_type=guidance_type,
            cfg_scale=scale,
            pag_scale=pag_scale,
            pag_applied_layers=pag_applied_layers,
            model_type="flow",
            model_kwargs=model_kwargs,
            schedule="FLOW",
        )
        samples = dpm_solver.sample(
            z,
            steps=sample_steps,
            order=2,
            skip_type="time_uniform_flow",
            method="multistep",
            flow_shift=flow_shift,
        )
    else:
        raise ValueError(f"{sampler} is not defined")

    samples = samples.to(weight_dtype)
    samples = vae_decode(config.vae.vae_type, vae, samples)
    samples = resize_and_crop_tensor(samples, orig_width, orig_height)
    display_model_info = (
        f"Model path: {args.model_path},\nBase image size: {args.image_size}, \nSampling Algo: {sampler}"
    )
    return ndarr_image(samples, normalize=True, value_range=(-1, 1)), prompt_show, display_model_info, seed
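

# Script entry point: parse the YAML config with pyrallis, build the Sana transformer,
# VAE and text encoder, load the checkpoint, and launch the Gradio demo.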
if __name__ == "__main__":
    from diffusion.utils.logger import get_root_logger

    args = get_args()
    config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config)
    # config = read_config(args.config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger = get_root_logger()

    args.image_size = config.model.image_size
    assert args.image_size in [
        256,
        512,
        1024,
        2048,
        4096,
    ], "We only provide pre-trained models for 256x256, 512x512, 1024x1024, 2048x2048 and 4096x4096 resolutions."

    # only support fixed latent size currently
    latent_size = config.model.image_size // config.vae.vae_downsample_rate
    max_sequence_length = config.text_encoder.model_max_length
    pe_interpolation = config.model.pe_interpolation
    micro_condition = config.model.micro_condition
    pag_applied_layers = config.model.pag_applied_layers
    flow_shift = config.scheduler.flow_shift
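
    # Map the configured mixed-precision mode onto the torch dtype used for the VAE,
    # the transformer weights, and the sampled latents.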
    if config.model.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif config.model.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    elif config.model.mixed_precision == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(f"weight precision {config.model.mixed_precision} is not defined")
    logger.info(f"Inference with {weight_dtype}")

    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
    tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)

    # model setting
    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
    model_kwargs = {
        "input_size": latent_size,
        "pe_interpolation": config.model.pe_interpolation,
        "config": config,
        "model_max_length": config.text_encoder.model_max_length,
        "qk_norm": config.model.qk_norm,
        "micro_condition": config.model.micro_condition,
        "caption_channels": text_encoder.config.hidden_size,
        "y_norm": config.text_encoder.y_norm,
        "attn_type": config.model.attn_type,
        "ffn_type": config.model.ffn_type,
        "mlp_ratio": config.model.mlp_ratio,
        "mlp_acts": list(config.model.mlp_acts),
        "in_channels": config.vae.vae_latent_dim,
        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
        "use_pe": config.model.use_pe,
        "linear_head_dim": config.model.linear_head_dim,
        "pred_sigma": pred_sigma,
        "learn_sigma": learn_sigma,
    }
    model = build_model(
        config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
    ).to(device)
    # model = build_model(config.model, **model_kwargs).to(device)
    logger.info(
        f"{model.__class__.__name__}:{config.model.model}, Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
    )
    logger.info("Generating sample from ckpt: %s" % args.model_path)
    state_dict = find_model(args.model_path)
    if "pos_embed" in state_dict["state_dict"]:
        del state_dict["state_dict"]["pos_embed"]
    missing, unexpected = model.load_state_dict(state_dict["state_dict"], strict=False)
    logger.warning(f"Missing keys: {missing}")
    logger.warning(f"Unexpected keys: {unexpected}")
    model.eval().to(weight_dtype)
    base_ratios = eval(f"ASPECT_RATIO_{args.image_size}_TEST")

    null_caption_token = tokenizer(
        "", max_length=max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)
    null_caption_embs = text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]

    model_size = "1.6" if "D20" in args.model_path else "0.6"
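
    # Gradio front end: static HTML header and description, followed by the demo interface.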
    title = f"""
        <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
            <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
        </div>
    """
    DESCRIPTION = f"""
        <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
        <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
        <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
        <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
    """
    if model_size == "0.6":
        DESCRIPTION += "\n<p>The 0.6B model's text rendering ability is limited.</p>"
    if not torch.cuda.is_available():
        DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
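
    # The input widgets below map positionally onto generate_img's parameters:
    # (prompt, sampler, sample_steps, scale, pag_scale, guidance_type, seed,
    #  randomize_seed, base_size, height, width).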
    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            Textbox(
                label="Note: If you want to specify an aspect ratio or a custom height and width, "
                "use --ar h:w (or --aspect_ratio h:w) or --hw h:w. If no aspect ratio or hw is given, all settings use their defaults.",
                placeholder="Please enter your prompt. \n",
            ),
            gr.Radio(
                choices=["dpm-solver", "sa-solver", "flow_dpm-solver", "flow_euler"],
                label="Sampler",
                interactive=True,
                value="flow_dpm-solver",
            ),
            gr.Slider(label="Sample Steps", minimum=1, maximum=100, value=20, step=1),
            gr.Slider(label="Guidance Scale", minimum=1.0, maximum=30.0, value=5.0, step=0.1),
            gr.Slider(label="PAG Scale", minimum=1.0, maximum=10.0, value=2.5, step=0.5),
            gr.Radio(
                choices=["classifier-free", "classifier-free_PAG", "classifier-free_PAG_seq"],
                label="Guidance Type",
                interactive=True,
                value="classifier-free_PAG_seq",
            ),
            gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            ),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Radio(
                choices=[256, 512, 1024, 2048, 4096],
                label="Base Size",
                interactive=True,
                value=args.image_size,
            ),
            gr.Slider(
                label="Height",
                minimum=256,
                maximum=6000,
                step=32,
                value=args.image_size,
            ),
            gr.Slider(
                label="Width",
                minimum=256,
                maximum=6000,
                step=32,
                value=args.image_size,
            ),
        ],
        outputs=[
            Image(type="numpy", label="Img"),
            Textbox(label="clean prompt"),
            Textbox(label="model info"),
            gr.Slider(label="seed"),
        ],
        title=title,
        description=DESCRIPTION,
        examples=examples,
    )
    demo.launch(server_name="0.0.0.0", server_port=args.port, debug=True, share=True)