import copy
import random

import numpy as np
import torch


def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            # Drop the caption for classifier-free guidance training.
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # Take a random caption if there are multiple; use the first at eval time.
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # SDXL conditions on the penultimate hidden state of each text encoder;
            # only the pooled output of the last text encoder survives the loop.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    # Concatenate the per-encoder embeddings along the channel dimension.
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
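
# Hedged usage sketch: with the two SDXL text encoders and tokenizers loaded
# elsewhere (the names `text_encoder_one`, `tokenizer_one`, etc. are assumptions,
# not names defined in this file):
#
#   prompt_embeds, pooled = encode_prompt(
#       ["a photo of an astronaut riding a horse"],
#       [text_encoder_one, text_encoder_two],
#       [tokenizer_one, tokenizer_two],
#       proportion_empty_prompts=0.0,
#   )
#   # For SDXL: prompt_embeds ~ (1, 77, 2048), pooled ~ (1, 1280).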
def compute_embeddings(
    prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers,
    is_train=True, device='cuda'
):
    target_size = (1024, 1024)
    original_sizes = torch.tensor(original_sizes, dtype=torch.long)
    crops_coords_top_left = torch.tensor(crop_coords, dtype=torch.long)

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
    )
    add_text_embeds = pooled_prompt_embeds

    # SDXL micro-conditioning: time_ids = (original_size, crop_coords_top_left, target_size).
    add_time_ids = list(target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
    add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
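
# Hedged usage sketch: in training/inference scripts this is typically bound with
# functools.partial (the encoder/tokenizer names below are assumptions):
#
#   compute_embeddings_fn = functools.partial(
#       compute_embeddings,
#       proportion_empty_prompts=0.0,
#       text_encoders=[text_encoder_one, text_encoder_two],
#       tokenizers=[tokenizer_one, tokenizer_two],
#   )
#   cond = compute_embeddings_fn(prompts, original_sizes, crop_coords)
#   # cond has keys "prompt_embeds", "text_embeds", "time_ids".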
def extract_into_tensor(a, t, x_shape):
    # Gather a[t] along the last dim and reshape to (b, 1, ..., 1) so the result
    # broadcasts against a tensor of shape x_shape.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of the guidance scale, following
    https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scales to embed, shape `(batch_size,)`
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad when embedding_dim is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
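
# Hedged usage sketch:
#
#   w = torch.full((4,), 8.0)  # guidance scale 8.0 for a batch of 4
#   w_emb = guidance_scale_embedding(w, embedding_dim=512)
#   assert w_emb.shape == (4, 512)
#
# The result is passed to the UNet as `timestep_cond` for models distilled with a
# guidance embedding, as done in the samplers below.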
def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
    sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
    alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)

    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)

    # A boundary timestep of 0 corresponds to clean data (alpha = 1, sigma = 0).
    alphas_s[boundary_timesteps == 0] = 1.0
    sigmas_s[boundary_timesteps == 0] = 0.0

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # predicted x_0
        pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output  # re-noise to the boundary timestep s
    elif prediction_type == "v_prediction":
        assert (boundary_timesteps == 0).all(), "v_prediction does not support multiple endpoints at the moment"
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0
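
# Hedged shape-check sketch (`alpha_schedule`/`sigma_schedule` are assumed to be
# sqrt(alphas_cumprod) and sqrt(1 - alphas_cumprod), as built in the samplers below):
#
#   x_t = torch.randn(2, 4, 64, 64)
#   eps = torch.randn_like(x_t)
#   t = torch.tensor([999, 999])
#   s = torch.tensor([0, 0])
#   x_s = predicted_origin(eps, t, s, x_t, "epsilon", alpha_schedule, sigma_schedule)
#   # With s == 0 this reduces to the plain x_0 estimate: (x_t - sigma_t * eps) / alpha_t.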
class DDIMSolver:
    def __init__(
        self, alpha_cumprods, timesteps=1000, ddim_timesteps=50,
        num_endpoints=1, num_inverse_endpoints=1,
        max_inverse_timestep_index=49,
        endpoints=None, inverse_endpoints=None
    ):
        # DDIM sampling parameters: select `ddim_timesteps` evenly spaced steps
        # out of the `timesteps` training steps.
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_alpha_cumprods_next = np.asarray(
            alpha_cumprods[self.ddim_timesteps[1:]].tolist() + [0.0]
        )

        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
        self.ddim_alpha_cumprods_next = torch.from_numpy(self.ddim_alpha_cumprods_next)

        # Boundary timesteps for the forward (denoising) direction; either parsed
        # from a comma-separated string or spread evenly over the DDIM steps.
        if endpoints is None:
            timestep_interval = ddim_timesteps // num_endpoints + int(ddim_timesteps % num_endpoints > 0)
            endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            self.endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
        else:
            self.endpoints = torch.tensor([int(endpoint) for endpoint in endpoints.split(',')])
            assert len(self.endpoints) == num_endpoints

        # Boundary timesteps for the inverse (noising) direction.
        if inverse_endpoints is None:
            timestep_interval = ddim_timesteps // num_inverse_endpoints + int(
                ddim_timesteps % num_inverse_endpoints > 0)
            inverse_endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])
            self.inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
        else:
            self.inverse_endpoints = torch.tensor([int(endpoint) for endpoint in inverse_endpoints.split(',')])
            assert len(self.inverse_endpoints) == num_inverse_endpoints

    def to(self, device):
        self.endpoints = self.endpoints.to(device)
        self.inverse_endpoints = self.inverse_endpoints.to(device)

        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        self.ddim_alpha_cumprods_next = self.ddim_alpha_cumprods_next.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        # One deterministic DDIM step towards t = 0:
        # x_prev = sqrt(a_prev) * x_0 + sqrt(1 - a_prev) * eps.
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def inverse_ddim_step(self, pred_x0, pred_noise, timestep_index):
        # One deterministic DDIM step away from t = 0 (inversion).
        alpha_cumprod_next = extract_into_tensor(self.ddim_alpha_cumprods_next, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_next).sqrt() * pred_noise
        x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
        return x_next
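
# Hedged usage sketch (mirrors how the samplers below construct the solver;
# `pipe` and `device` are assumed to exist):
#
#   solver = DDIMSolver(
#       pipe.scheduler.alphas_cumprod.cpu().numpy(),
#       timesteps=pipe.scheduler.num_train_timesteps,
#       ddim_timesteps=50,
#       num_endpoints=4,
#       num_inverse_endpoints=4,
#   ).to(device)
#   # solver.endpoints / solver.inverse_endpoints hold the boundary timesteps
#   # used for multi-boundary consistency sampling and inversion.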
@torch.no_grad()
def inverse_sample_deterministic(
    pipe,
    images,
    prompt,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_start_latent=False,
    guidance_scale=None,
    compute_embeddings_fn=None,
    is_sdxl=False,
    inverse_endpoints=None,
    seed=0,
):
    # Deterministic multi-boundary inversion: maps images to noise latents by
    # iterating the model over the (t, s) segments defined by the solver.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Encode the prompts (with SDXL micro-conditioning if needed).
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # The forward endpoints are the inverse endpoints shifted by one, starting from 0.
    endpoints = ','.join(['0'] + inverse_endpoints.split(',')[:-1]) if inverse_endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.cpu().numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints,
    ).to(device)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        # Custom timesteps are expected as an ascending Python list; each boundary
        # is the next timestep, with the final boundary capped at 999.
        boundary_timesteps = timesteps[1:] + [timesteps[0]]
        boundary_timesteps[-1] = 999
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    # The img2img-style prepare_latents encodes the images and noises them to
    # timesteps[0]. A fixed seed keeps the inversion reproducible; note that the
    # `generator` argument is not used here.
    start_latents = pipe.prepare_latents(
        images, timesteps[0], batch_size, 1, prompt_embeds.dtype, device,
        generator=torch.Generator().manual_seed(seed),
    )
    latents = start_latents.clone()

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    # Walk through the inversion segments: each step predicts the boundary point
    # at timestep s from the current point at timestep t.
    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        noise_pred = pipe.unet(
            latents.to(prompt_embeds.dtype),
            t,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(latents), device=device),
            torch.tensor([s] * len(latents), device=device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(prompt_embeds.dtype)

    if return_start_latent:
        return latents, start_latents
    else:
        return latents
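
# Hedged usage sketch (`pipe` is assumed to be an img2img-style pipeline whose
# prepare_latents encodes images; the endpoint values are illustrative only):
#
#   noisy_latents = inverse_sample_deterministic(
#       pipe, images, prompts,
#       num_scales=50, num_inference_steps=4,
#       inverse_endpoints="259,499,699,999",
#       guidance_scale=1.0, is_sdxl=True,
#       compute_embeddings_fn=compute_embeddings_fn,
#   )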
def linear_schedule_old(t, guidance_scale, tau1, tau2):
    # Linearly decay the guidance scale from full strength (t <= tau1) to zero
    # (t >= tau2), with t normalized from [0, 1000] to [0, 1].
    t = t / 1000
    if t <= tau1:
        gamma = 1.0
    elif t >= tau2:
        gamma = 0.0
    else:
        gamma = (tau2 - t) / (tau2 - tau1)
    return gamma * guidance_scale
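
# Hedged example: with tau1=0.3, tau2=0.7 and guidance_scale=8.0,
#   t=200 -> 0.2 <= tau1                    -> 8.0 (full guidance)
#   t=500 -> gamma = (0.7 - 0.5) / 0.4 = 0.5 -> 4.0
#   t=800 -> 0.8 >= tau2                    -> 0.0 (no guidance)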
@torch.no_grad()
def sample_deterministic(
    pipe,
    prompt,
    latents=None,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_latent=False,
    guidance_scale=None,
    compute_embeddings_fn=None,
    is_sdxl=False,
    endpoints=None,
    use_dynamic_guidance=False,
    tau1=0.7,
    tau2=0.7,
    amplify_prompt=None,
):
    # Deterministic multi-boundary consistency sampling: maps noise latents to
    # images by iterating the model over the (t, s) segments defined by the solver.
    height = pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = pipe.unet.config.sample_size * pipe.vae_scale_factor

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Encode the prompts (with SDXL micro-conditioning if needed).
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
            if amplify_prompt is not None:
                # A second set of embeddings used for the early, high-noise steps
                # when dynamic guidance is enabled.
                orig_size = [(1024, 1024)] * len(amplify_prompt)
                crop_coords = [(0, 0)] * len(amplify_prompt)
                encoded_text_old = compute_embeddings_fn(amplify_prompt, orig_size, crop_coords)
                amplify_prompt_embeds = encoded_text_old.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # The inverse endpoints are the forward endpoints shifted by one, ending at 999.
    inverse_endpoints = ','.join(endpoints.split(',')[1:] + ['999']) if endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.cpu().numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints,
    ).to(device)

    prompt_embeds_init = copy.deepcopy(prompt_embeds)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        # Custom timesteps are expected as an ascending Python list; they are
        # reversed, and each boundary is the next (lower) timestep, ending at 0.
        timesteps, boundary_timesteps = copy.deepcopy(timesteps), copy.deepcopy(timesteps)
        timesteps.reverse()
        boundary_timesteps.reverse()
        boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
        boundary_timesteps[-1] = 0
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    if latents is None:
        num_channels_latents = pipe.unet.config.in_channels
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )
        assert latents.dtype == pipe.unet.dtype
    else:
        latents = latents.to(prompt_embeds.dtype)

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    # Iterate from high to low noise levels: each step predicts the boundary
    # point at timestep s from the current point at timestep t.
    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        if use_dynamic_guidance:
            # Dynamic guidance: amplify the prompt at high noise levels and decay
            # the guidance scale according to linear_schedule_old. Requires
            # `guidance_scale` to be set, since `w` is derived from it above.
            t_item = t if isinstance(t, int) else t.item()
            if t_item > tau1 * 1000 and amplify_prompt is not None:
                prompt_embeds = amplify_prompt_embeds
            else:
                prompt_embeds = prompt_embeds_init
            guidance_scale_t = linear_schedule_old(t_item, w, tau1=tau1, tau2=tau2)  # shape: (batch_size,)
            w_embedding = guidance_scale_embedding(guidance_scale_t, embedding_dim=512)
            w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

        noise_pred = pipe.unet(
            latents,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(noise_pred)).to(device),
            torch.tensor([s] * len(noise_pred)).to(device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(pipe.unet.dtype)

    # Decode in fp32 for numerical stability of the VAE.
    pipe.vae.to(torch.float32)
    image = pipe.vae.decode(latents.to(torch.float32) / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)

    if return_latent:
        return image, latents
    else:
        return image
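
# Hedged usage sketch (endpoint values are illustrative; `guidance_scale` via
# `timestep_cond` assumes a guidance-distilled UNet with a time_cond projection):
#
#   images = sample_deterministic(
#       pipe, ["a photo of a corgi"],
#       num_scales=50, num_inference_steps=4,
#       endpoints="0,259,499,699",
#       guidance_scale=7.0, is_sdxl=True,
#       compute_embeddings_fn=compute_embeddings_fn,
#   )
#   images[0].save("sample.png")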