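# app.py — Gradio demo (Hugging Face Spaces, ZeroGPU) for few-step SDXL image
# generation using a DMD2-distilled UNet checkpoint.
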
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler, AutoencoderTiny
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from accelerate import Accelerator
from huggingface_hub import hf_hub_download
import spaces
import gradio as gr
import numpy as np
import torch
import time
import PIL.Image
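
# Base SDXL weights plus the DMD2 4-step distilled UNet checkpoint that replaces the base UNet.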
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo_id = "tianweiy/DMD2"
subfolder = "model/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_checkpoint_model_019000"
filename = "pytorch_model.bin"
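
# Wraps the SDXL components (tokenizers, text encoders, VAEs, distilled UNet)
# and exposes a single `inference` entry point for the Gradio UI.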
class ModelWrapper:
    def __init__(self, model_id, checkpoint_path, precision, image_resolution, latent_resolution,
                 num_train_timesteps, conditioning_timestep, num_step, revision, accelerator):
        super().__init__()
        torch.set_grad_enabled(False)

        self.DTYPE = getattr(torch, precision)
        self.device = accelerator.device

        self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
        self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", revision=revision, use_fast=False)

        self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)

        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
        self.vae_dtype = torch.float32

        self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
        self.tiny_vae_dtype = self.DTYPE

        self.model = self.create_generator(model_id, checkpoint_path).to(dtype=self.DTYPE).to(self.device)

        self.accelerator = accelerator
        self.image_resolution = image_resolution
        self.latent_resolution = latent_resolution
        self.num_train_timesteps = num_train_timesteps
        self.vae_downsample_ratio = image_resolution // latent_resolution
        self.conditioning_timestep = conditioning_timestep

        self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.num_step = num_step
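
    # Load the base SDXL UNet architecture, then overwrite its weights with the
    # DMD2 distilled generator checkpoint.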
    def create_generator(self, model_id, checkpoint_path):
        generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        generator.load_state_dict(state_dict, strict=True)
        generator.requires_grad_(False)
        return generator
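
    # SDXL micro-conditioning: original size, crop top-left corner, and target size
    # packed into a single `add_time_ids` tensor.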
    def build_condition_input(self, height, width):
        original_size = (height, width)
        target_size = (height, width)
        crop_top_left = (0, 0)

        add_time_ids = list(original_size + crop_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], device=self.device, dtype=self.DTYPE)
        return add_time_ids
    def _encode_prompt(self, prompt):
        text_input_ids_one = self.tokenizer_one([prompt], padding="max_length", max_length=self.tokenizer_one.model_max_length, truncation=True, return_tensors="pt").input_ids
        text_input_ids_two = self.tokenizer_two([prompt], padding="max_length", max_length=self.tokenizer_two.model_max_length, truncation=True, return_tensors="pt").input_ids

        prompt_dict = {
            'text_input_ids_one': text_input_ids_one.unsqueeze(0).to(self.device),
            'text_input_ids_two': text_input_ids_two.unsqueeze(0).to(self.device)
        }
        return prompt_dict
    @staticmethod
    def _get_time():
        torch.cuda.synchronize()
        return time.time()
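
    # Few-step sampling loop: at each timestep the distilled UNet predicts noise,
    # an x0 estimate is recovered, and the estimate is re-noised to the next, lower
    # timestep (the re-noised latent produced on the final step is unused). The last
    # x0 estimate is decoded with either the full SDXL VAE or the tiny VAE.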
    def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
        alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        if self.num_step == 1:
            all_timesteps = [self.conditioning_timestep]
            step_interval = 0
        elif self.num_step == 4:
            all_timesteps = [999, 749, 499, 249]
            step_interval = 250
        else:
            raise NotImplementedError()

        DTYPE = prompt_embed.dtype

        for constant in all_timesteps:
            current_timesteps = torch.ones(len(prompt_embed), device=self.device, dtype=torch.long) * constant
            eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
            eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)

            next_timestep = current_timesteps - step_interval
            noise = self.scheduler.add_noise(eval_images, torch.randn_like(eval_images), next_timestep).to(DTYPE)

        if fast_vae_decode:
            eval_images = self.tiny_vae.decode(eval_images.to(self.tiny_vae_dtype) / self.tiny_vae.config.scaling_factor, return_dict=False)[0]
        else:
            eval_images = self.vae.decode(eval_images.to(self.vae_dtype) / self.vae.config.scaling_factor, return_dict=False)[0]

        eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
        return eval_images
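
    # Main entry point called by the Gradio UI. Runs on the Space's GPU via the
    # `spaces.GPU` decorator and returns a list of PIL images plus a status string.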
    @spaces.GPU(enable_queue=True)
    @torch.no_grad()
    def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
        print("Running model inference...")

        if seed == -1:
            seed = np.random.randint(0, 1000000)

        generator = torch.manual_seed(seed)

        add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)

        noise = torch.randn(
            num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio,
            generator=generator
        ).to(device=self.device, dtype=self.DTYPE)

        prompt_inputs = self._encode_prompt(prompt)

        start_time = self._get_time()

        prompt_embeds, pooled_prompt_embeds = self.text_encoder(prompt_inputs)
        batch_prompt_embeds, batch_pooled_prompt_embeds = (
            prompt_embeds.repeat(num_images, 1, 1),
            pooled_prompt_embeds.repeat(num_images, 1, 1)
        )

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": batch_pooled_prompt_embeds.squeeze(1)
        }

        eval_images = self.sample(noise=noise, unet_added_conditions=unet_added_conditions, prompt_embed=batch_prompt_embeds, fast_vae_decode=fast_vae_decode)

        end_time = self._get_time()

        output_image_list = []
        for image in eval_images:
            output_image_list.append(PIL.Image.fromarray(image.cpu().numpy()))

        return output_image_list, f"Ran successfully in {(end_time - start_time):.2f} seconds"
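
# DDIM-style conversion of an epsilon prediction into an x0 estimate:
#   x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)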
def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
    alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
    beta_prod_t = 1 - alpha_prod_t

    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    return pred_original_sample
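
# Runs both SDXL text encoders: the per-token embeddings are the concatenation of
# each encoder's penultimate hidden states, while the pooled embedding comes from
# the second (projection) encoder.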
class SDXLTextEncoder(torch.nn.Module):
    def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
        super().__init__()
        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(accelerator.device).to(dtype=dtype)
        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(accelerator.device).to(dtype=dtype)
        self.accelerator = accelerator

    def forward(self, batch):
        text_input_ids_one = batch['text_input_ids_one'].to(self.accelerator.device).squeeze(1)
        text_input_ids_two = batch['text_input_ids_two'].to(self.accelerator.device).squeeze(1)
        prompt_embeds_list = []

        for text_input_ids, text_encoder in zip([text_input_ids_one, text_input_ids_two], [self.text_encoder_one, self.text_encoder_two]):
            prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(len(text_input_ids_one), -1)
        return prompt_embeds, pooled_prompt_embeds
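
# Build the Gradio Blocks UI: instantiate the model once, wire the inputs to
# ModelWrapper.inference, and display results in a gallery.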
def create_demo():
    TITLE = "# DMD2-SDXL Demo"

    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    checkpoint_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=filename)

    precision = "float16"
    image_resolution = 1024
    latent_resolution = 128
    num_train_timesteps = 1000
    conditioning_timestep = 999
    num_step = 4
    revision = None

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    accelerator = Accelerator()

    model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution,
                         num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        with gr.Row():
            with gr.Column():
                prompt = gr.Text(value="An oil painting of two rabbits in the style of American Gothic, wearing the same clothes as in the original.", label="Prompt")
                run_button = gr.Button("Run")
                with gr.Accordion(label="Advanced options", open=True):
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
                    num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=16)
                    fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
                    height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=1024)
                    width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=1024)
            with gr.Column():
                result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
                error_message = gr.Text(label="Job Status")

        inputs = [prompt, seed, height, width, num_images, fast_vae_decode]
        run_button.click(fn=model.inference, inputs=inputs, outputs=[result, error_message], concurrency_limit=1)

    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(api_open=False)
    demo.launch(show_error=True, share=True)