from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torchvision.transforms import ToPILImage

from models.vqvae import VQVAEHF
from models.clip import FrozenCLIPEmbedder
from models.var import TVARHF, sample_with_top_k_top_p_, gumbel_softmax_with_rng

class TVARPipeline:
    vae_path = "michellemoorre/vae-test"
    text_encoder_path = "openai/clip-vit-large-patch14"
    text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    def __init__(self, var, vae, text_encoder, text_encoder_2, device):
        self.var = var
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.var.eval()
        self.vae.eval()
        self.device = device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
        var = TVARHF.from_pretrained(pretrained_model_name_or_path).to(device)
        vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
        text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
        text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
        return cls(var, vae, text_encoder, text_encoder_2, device)

    @staticmethod
    def to_image(tensor):
        # Convert a batch of [0, 1] image tensors to a list of PIL images.
        return [ToPILImage()((255 * img.cpu().detach()).to(torch.uint8))
                for img in tensor]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        null_prompt: str = "",
        encode_null: bool = True,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        encodings = [
            self.text_encoder.encode(prompt),
            self.text_encoder_2.encode(prompt),
        ]
        prompt_embeds = torch.concat(
            [encoding.last_hidden_state for encoding in encodings], dim=-1
        )
        pooled_prompt_embeds = encodings[-1].pooler_output
        attn_bias = encodings[-1].attn_bias

        if encode_null:
            null_prompt = [null_prompt] if isinstance(null_prompt, str) else null_prompt
            null_encodings = [
                self.text_encoder.encode(null_prompt),
                self.text_encoder_2.encode(null_prompt),
            ]
            null_prompt_embeds = torch.concat(
                [encoding.last_hidden_state for encoding in null_encodings], dim=-1
            )
            null_pooled_prompt_embeds = null_encodings[-1].pooler_output
            null_attn_bias = null_encodings[-1].attn_bias

            # Broadcast the null (unconditional) embeddings to the conditional batch and
            # stack them after it, doubling the batch for classifier-free guidance.
            B, L, hidden_dim = prompt_embeds.shape
            pooled_dim = pooled_prompt_embeds.shape[1]
            null_prompt_embeds = null_prompt_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device)
            null_pooled_prompt_embeds = null_pooled_prompt_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device)
            null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)

            prompt_embeds = torch.cat([prompt_embeds, null_prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_prompt_embeds], dim=0)
            attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)

        return prompt_embeds, pooled_prompt_embeds, attn_bias

    def __call__(
        self,
        prompt=None,
        null_prompt="",
        g_seed: Optional[int] = None,
        cfg=4.0,
        top_k=450,
        top_p=0.95,
        more_smooth=False,
        re=False,
        re_max_depth=10,
        re_start_iter=2,
        return_pil=True,
        encoded_prompt=None,
        encoded_null_prompt=None,
    ) -> torch.Tensor:  # generated images (B, 3, H, W) in [0, 1], or a list of PIL images
        """
        Inference-only autoregressive sampling.
        :param prompt: text prompt or list of prompts to condition on
        :param null_prompt: unconditional prompt for classifier-free guidance
        :param g_seed: random seed; if None, sampling is unseeded
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smooth the prediction using Gumbel softmax; only used for visualization, not for FID/IS benchmarking
        :param re: resample tokens at each scale (from re_start_iter, up to re_max_depth redraws) and keep the draw with the highest total logit
        :param return_pil: if True, return a list of PIL images instead of a tensor
        :param encoded_prompt: precomputed text encodings; if given (together with encoded_null_prompt), prompt/null_prompt are ignored
        :return: generated image(s)
        """
        assert not self.var.training
        var = self.var
        vae = self.vae
        vae_quant = self.vae.quantize

        if g_seed is None:
            rng = None
        else:
            var.rng.manual_seed(g_seed)
            rng = var.rng

        if encoded_prompt is not None:
            assert encoded_null_prompt is not None
            context, cond_vector, context_attn_bias = self.var.parse_batch(
                encoded_prompt,
                encoded_null_prompt,
            )
        else:
            context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)

        # The context is stacked [conditional; null], so the true batch size is half of it.
        B = context.shape[0] // 2
        cond_vector = var.text_pooler(cond_vector)
        sos = cond_BD = cond_vector

        lvl_pos = var.lvl_embed(var.lvl_1L)
        if not var.rope:
            lvl_pos += var.pos_1LC
        next_token_map = (
            sos.unsqueeze(1)
            + var.pos_start.expand(2 * B, var.first_l, -1)
            + lvl_pos[:, : var.first_l]
        )

        cur_L = 0
        f_hat = sos.new_zeros(B, var.Cvae, var.patch_nums[-1], var.patch_nums[-1])

        # Enable KV caching for autoregressive decoding.
        for b in var.blocks:
            b.attn.kv_caching(True)
            b.cross_attn.kv_caching(True)

        for si, pn in enumerate(var.patch_nums):  # si: i-th scale (segment)
            ratio = si / var.num_stages_minus_1
            cond_BD_or_gss = var.shared_ada_lin(cond_BD)
            x_BLC = next_token_map

            if var.rope:
                freqs_cis = var.freqs_cis[:, cur_L : cur_L + pn * pn]
            else:
                freqs_cis = var.freqs_cis

            for block in var.blocks:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=None,
                    context=context,
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            cur_L += pn * pn

            logits_BlV = var.get_logits(x_BLC, cond_BD)

            # Classifier-free guidance, with a strength that ramps up over the scales.
            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(
                logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
            )[:, :, 0]

            if re and si >= re_start_iter:
                # Resampling: redraw the token map up to re_max_depth times and keep the
                # draw whose selected logits have the largest sum.
                selected_logits = torch.gather(logits_BlV, -1, idx_Bl.unsqueeze(-1))[:, :, 0]
                mx = selected_logits.sum(dim=-1)[:, None]
                for _ in range(re_max_depth):
                    new_idx_Bl = sample_with_top_k_top_p_(
                        logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
                    )[:, :, 0]
                    selected_logits = torch.gather(logits_BlV, -1, new_idx_Bl.unsqueeze(-1))[:, :, 0]
                    new_mx = selected_logits.sum(dim=-1)[:, None]
                    idx_Bl = idx_Bl * (mx >= new_mx) + new_idx_Bl * (mx < new_mx)
                    mx = mx * (mx >= new_mx) + new_mx * (mx < new_mx)

            if not more_smooth:  # this is the default case
                h_BChw = vae_quant.embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to MaskGIT
                h_BChw = gumbel_softmax_with_rng(
                    logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng
                ) @ vae_quant.embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, var.Cvae, pn, pn)
            f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
                si, len(var.patch_nums), f_hat, h_BChw
            )

            if si != var.num_stages_minus_1:  # prepare for the next scale
                next_token_map = next_token_map.view(B, var.Cvae, -1).transpose(1, 2)
                next_token_map = (
                    var.word_embed(next_token_map)
                    + lvl_pos[:, cur_L : cur_L + var.patch_nums[si + 1] ** 2]
                )
                next_token_map = next_token_map.repeat(2, 1, 1)  # double the batch size due to CFG

        for b in var.blocks:
            b.attn.kv_caching(False)
            b.cross_attn.kv_caching(False)

        # De-normalize from [-1, 1] to [0, 1].
        img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
        if return_pil:
            img = self.to_image(img)
        return img
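

# Minimal usage sketch (not part of the original pipeline code): the checkpoint id below is a
# placeholder assumption; substitute the actual TVAR checkpoint path or Hub id. It assumes a
# CUDA device is available, since from_pretrained defaults to device="cuda".
if __name__ == "__main__":
    pipe = TVARPipeline.from_pretrained("<path-or-hub-id-of-tvar-checkpoint>")
    images = pipe(
        prompt="a photo of an astronaut riding a horse",
        cfg=4.0,
        top_k=450,
        top_p=0.95,
        g_seed=42,
        return_pil=True,
    )
    images[0].save("sample.png")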