# dreamgaussian4d / guidance / sd_utils.py
# jiaweir
# init
# 21c4e64
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
PNDMScheduler,
DDIMScheduler,
StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
# suppress partial model loading warning
logging.set_verbosity_error()
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def seed_everything(seed):
    """Seed every random number generator this module may draw from.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and the
    current CUDA device), so repeated runs with the same seed draw the same
    noise.

    Args:
        seed: integer seed. Any Python int is accepted; NumPy receives it
            reduced modulo 2**32 because ``np.random.seed`` only accepts
            values in [0, 2**32 - 1].
    """
    import random

    random.seed(seed)
    # np.random.seed rejects negative or >= 2**32 values; fold them into range.
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Bit-exact cuDNN results would additionally need
    # torch.backends.cudnn.deterministic = True (and benchmark = False),
    # at a speed cost; intentionally not enabled here.
class StableDiffusion(nn.Module):
    """Stable Diffusion wrapper used as a 2D diffusion prior.

    Loads the VAE, tokenizer, text encoder and UNet of a Stable Diffusion
    checkpoint through ``StableDiffusionPipeline`` and replaces its scheduler
    with a ``DDIMScheduler``. Provides score distillation (``train_step``),
    partial-noise refinement (``refine``) and text-to-image sampling
    (``prompt_to_img``).

    ``get_text_embeds`` caches the conditioning in ``self.embeddings`` as
    ``cat([negative, positive])``. Every method that runs the UNet relies on
    that layout for classifier-free guidance (CFG).
    """

    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        sd_version="2.1",
        hf_key=None,
        t_range=(0.02, 0.98),
    ):
        """Load the diffusion components.

        Args:
            device: torch device the pipeline (and ``alphas``) is placed on.
            fp16: load weights as ``torch.float16`` if True, else ``float32``.
            vram_O: use the pipeline's low-VRAM options (sequential CPU
                offload, VAE slicing, channels-last UNet, attention slicing)
                instead of moving the whole pipeline to ``device``.
            sd_version: "2.1", "2.0" or "1.5"; ignored when ``hf_key`` is set.
            hf_key: custom Hugging Face model id that overrides ``sd_version``.
            t_range: (low, high) fractions of the training timesteps that
                ``train_step`` may use. It is a tuple so that no mutable
                default is shared; any indexable pair works.

        Raises:
            ValueError: if ``sd_version`` is unknown and ``hf_key`` is None.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
        elif self.sd_version == "2.1":
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == "2.0":
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == "1.5":
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )

        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        # Bounds (inclusive) on the timesteps train_step may use.
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        # Cumulative alpha products, kept on `device` so they can be indexed by t.
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        self.embeddings = None

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        """Encode prompts and cache ``cat([negative, positive])`` in ``self.embeddings``.

        Args:
            prompts: list of positive prompt strings.
            negative_prompts: list of negative prompt strings. It should have
                the same length as ``prompts``, because the cached tensor is
                later split in half.
        """
        pos_embeds = self.encode_text(prompts)  # e.g. [1, 77, D]
        neg_embeds = self.encode_text(negative_prompts)
        # Negative first: every CFG split below treats the first half as unconditional.
        self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)  # e.g. [2, 77, D]

    def encode_text(self, prompt):
        """Tokenize ``prompt`` (str or list of str) and return the text-encoder output.

        Prompts are padded and truncated to ``tokenizer.model_max_length``.
        Without truncation, an over-long prompt would produce more tokens
        than the text encoder accepts.
        """
        inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def _cfg_embeddings(self, batch_size):
        """Return conditioning laid out as ``[neg x batch_size, pos x batch_size]``.

        The UNet input is ``torch.cat([latents] * 2)``: all samples once for
        the unconditional half, then again for the conditional half. The
        output is then split with ``chunk(2)``, so the conditioning must use
        the same block layout.

        A single cached (neg, pos) pair is broadcast over the batch. Cached
        embeddings that already hold one pair per sample are used as-is.
        """
        neg_embeds, pos_embeds = self.embeddings.chunk(2)
        if neg_embeds.shape[0] != batch_size:
            # One prompt shared by all samples: broadcast it over the batch.
            neg_embeds = neg_embeds.expand(batch_size, -1, -1)
            pos_embeds = pos_embeds.expand(batch_size, -1, -1)
        return torch.cat([neg_embeds, pos_embeds], dim=0)

    @torch.no_grad()
    def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8):
        """Partially noise ``pred_rgb``, then denoise it with DDIM + CFG.

        Args:
            pred_rgb: [B, 3, H, W] images, resized to 512x512. Presumably
                in [0, 1], since ``encode_imgs`` maps inputs with ``2x - 1``.
            guidance_scale: CFG weight.
            steps: number of DDIM steps in the full schedule.
            strength: fraction of ``steps`` giving the schedule index where
                noising and denoising start. It is clamped to a valid index.
                NOTE(review): the DDIM schedule is assumed to run from high
                to low noise. A larger ``strength`` therefore adds *less*
                noise and runs fewer steps, the reverse of the diffusers
                img2img convention.

        Returns:
            Decoded images clamped to [0, 1] (512x512 for standard SD VAEs).
        """
        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(
            pred_rgb, (512, 512), mode="bilinear", align_corners=False
        )
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))

        self.scheduler.set_timesteps(steps)
        # Clamp so that strength >= 1 (or < 0) cannot index past the schedule.
        init_step = min(max(int(steps * strength), 0), steps - 1)
        latents = self.scheduler.add_noise(
            latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]
        )
        embeddings = self._cfg_embeddings(batch_size)

        for t in self.scheduler.timesteps[init_step:]:
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings,
            ).sample
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)
        return imgs

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=100,
        as_latent=False,
    ):
        """Score distillation (SDS) loss for a batch of renders.

        Args:
            pred_rgb: [B, 3, H, W] images (presumably in [0, 1]) that are
                VAE-encoded, or latent-like maps when ``as_latent`` is True.
                Latent-like maps presumably have the UNet's channel count.
            step_ratio: optional training progress. When given, every sample
                uses ``t = round((1 - step_ratio) * T)`` clipped to
                [min_step, max_step]. Otherwise t is drawn uniformly from
                [min_step, max_step] for each sample.
            guidance_scale: CFG weight.
            as_latent: resize ``pred_rgb`` to 64x64 and map it with ``2x - 1``
                instead of encoding it with the VAE.

        Returns:
            Scalar loss. Its gradient w.r.t. the latents is
            ``w(t) * (eps_cfg - eps) / B``, where ``w(t) = 1 - alphas_cumprod[t]``.
        """
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(
                pred_rgb, (64, 64), mode="bilinear", align_corners=False
            ) * 2 - 1
        else:
            # Resize to the VAE's native 512x512 input.
            pred_rgb_512 = F.interpolate(
                pred_rgb, (512, 512), mode="bilinear", align_corners=False
            )
            # Encoding must keep the graph: gradients flow back to pred_rgb.
            latents = self.encode_imgs(pred_rgb_512)

        if step_ratio is not None:
            # dreamtime-like: anneal t linearly from high to low as training progresses.
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(
                self.min_step, self.max_step
            )
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                (batch_size,),
                dtype=torch.long,
                device=self.device,
            )

        # w(t) = sigma_t^2 = 1 - alphas_cumprod[t]
        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        # The UNet prediction is treated as a constant (no grad through it).
        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(
                latent_model_input,
                tt,
                encoder_hidden_states=self._cfg_embeddings(batch_size),
            ).sample

            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_pos - noise_pred_uncond
            )

        grad = w * (noise_pred - noise)
        # Replace NaN/inf so a single bad value cannot poison the update.
        grad = torch.nan_to_num(grad)

        # MSE against a detached target makes d(loss)/d(latents) == grad / B.
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction="sum") / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        """Sample latents with DDIM and CFG from the cached text embeddings.

        Args:
            height, width: target image size in pixels; latents are 1/8 of it.
            num_inference_steps: number of DDIM steps.
            guidance_scale: CFG weight.
            latents: optional initial noise [N, C, height // 8, width // 8].
                It is moved to ``self.device`` / ``self.dtype``. By default,
                one sample per cached prompt pair is drawn.

        Returns:
            The final denoised latents.
        """
        if latents is None:
            latents = torch.randn(
                (
                    self.embeddings.shape[0] // 2,
                    self.unet.config.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
                dtype=self.dtype,  # must match the UNet weights (e.g. fp16)
            )
        else:
            latents = latents.to(device=self.device, dtype=self.dtype)

        embeddings = self._cfg_embeddings(latents.shape[0])
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # Duplicate the batch so one UNet call yields both CFG branches.
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings
            ).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # x_t -> x_{t-1}
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        """Decode VAE latents to images clamped to [0, 1]."""
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        """Encode [B, 3, H, W] images to scaled latents.

        Inputs are mapped with ``2x - 1`` (so presumably in [0, 1]). The
        latent is *sampled* from the posterior, not taken as its mean.
        """
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        """Generate images from text.

        Args:
            prompts: a prompt string or a list of them.
            negative_prompts: a negative prompt string or a list of them.
            height, width, num_inference_steps, guidance_scale, latents:
                forwarded to ``produce_latents``.

        Returns:
            uint8 NumPy array of shape [N, H, W, 3].
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)

        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        # Img latents -> imgs
        imgs = self.decode_latents(latents)

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs
if __name__ == "__main__":
    # Command-line demo: render one prompt to an image and display it.
    import argparse

    import matplotlib.pyplot as plt

    cli = argparse.ArgumentParser()
    cli.add_argument("prompt", type=str)
    cli.add_argument("--negative", type=str, default="")
    cli.add_argument(
        "--sd_version",
        type=str,
        choices=["1.5", "2.0", "2.1"],
        default="2.1",
        help="stable diffusion version",
    )
    cli.add_argument(
        "--hf_key",
        type=str,
        default=None,
        help="hugging face Stable diffusion model key",
    )
    cli.add_argument("--fp16", action="store_true", help="use float16 for training")
    cli.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    cli.add_argument("-H", type=int, default=512)
    cli.add_argument("-W", type=int, default=512)
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--steps", type=int, default=50)
    args = cli.parse_args()

    seed_everything(args.seed)

    sd = StableDiffusion(
        device=torch.device("cuda"),
        fp16=args.fp16,
        vram_O=args.vram_O,
        sd_version=args.sd_version,
        hf_key=args.hf_key,
    )
    images = sd.prompt_to_img(
        args.prompt,
        negative_prompts=args.negative,
        height=args.H,
        width=args.W,
        num_inference_steps=args.steps,
    )

    # Show the first generated image.
    plt.imshow(images[0])
    plt.show()