from typing import List, Optional, Union

import torch
from torchvision.transforms import ToPILImage

from models.vqvae import VQVAEHF
from models.clip import FrozenCLIPEmbedder
from models.var import TVARHF, sample_with_top_k_top_p_, gumbel_softmax_with_rng


class TVARPipeline:
    vae_path = "michellemoorre/vae-test"
    text_encoder_path = "openai/clip-vit-large-patch14"
    text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    def __init__(self, var, vae, text_encoder, text_encoder_2, device):
        self.var = var
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.var.eval()
        self.vae.eval()
        self.device = device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
        var = TVARHF.from_pretrained(pretrained_model_name_or_path).to(device)
        vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
        text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
        text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
        return cls(var, vae, text_encoder, text_encoder_2, device)

    @staticmethod
    def to_image(tensor):
        return [
            ToPILImage()((255 * img.cpu().detach()).to(torch.uint8)) for img in tensor
        ]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        null_prompt: str = "",
        encode_null: bool = True,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        encodings = [
            self.text_encoder.encode(prompt),
            self.text_encoder_2.encode(prompt),
        ]
        # Concatenate the hidden states of both CLIP encoders along the
        # channel dimension; pooled output and attention bias come from the
        # larger (second) encoder.
        prompt_embeds = torch.concat(
            [encoding.last_hidden_state for encoding in encodings], dim=-1
        )
        pooled_prompt_embeds = encodings[-1].pooler_output
        attn_bias = encodings[-1].attn_bias

        if encode_null:
            null_prompt = [null_prompt] if isinstance(null_prompt, str) else null_prompt
            null_encodings = [
                self.text_encoder.encode(null_prompt),
                self.text_encoder_2.encode(null_prompt),
            ]
            null_prompt_embeds = torch.concat(
                [encoding.last_hidden_state for encoding in null_encodings], dim=-1
            )
            null_pooled_prompt_embeds = null_encodings[-1].pooler_output
            null_attn_bias = null_encodings[-1].attn_bias

            B, L, hidden_dim = prompt_embeds.shape
            pooled_dim = pooled_prompt_embeds.shape[1]
            # Broadcast the single null encoding to the batch size and
            # truncate it to the conditional sequence length.
            null_prompt_embeds = (
                null_prompt_embeds[:, :L]
                .expand(B, L, hidden_dim)
                .to(prompt_embeds.device)
            )
            null_pooled_prompt_embeds = null_pooled_prompt_embeds.expand(
                B, pooled_dim
            ).to(pooled_prompt_embeds.device)
            null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)

            # Stack [cond; uncond] along the batch dimension so one forward
            # pass computes both classifier-free guidance branches.
            prompt_embeds = torch.cat([prompt_embeds, null_prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [pooled_prompt_embeds, null_pooled_prompt_embeds], dim=0
            )
            attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)

        return prompt_embeds, pooled_prompt_embeds, attn_bias
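
    # The [cond; uncond] batch layout produced by `encode_prompt` is what
    # `__call__` relies on: the first B rows of each returned tensor hold the
    # conditional encodings and the last B rows the null encodings. A minimal
    # sketch of the guidance mixing that consumes this layout (hypothetical
    # names, not part of the pipeline):
    #
    #     logits = model(torch.cat([cond, uncond], dim=0))  # (2B, l, V)
    #     t = cfg * ratio                                   # guidance ramps with stage
    #     guided = (1 + t) * logits[:B] - t * logits[B:]    # (B, l, V)
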
    @torch.inference_mode()
    def __call__(
        self,
        prompt=None,
        null_prompt="",
        g_seed: Optional[int] = None,
        cfg=4.0,
        top_k=450,
        top_p=0.95,
        more_smooth=False,
        re=False,
        re_max_depth=10,
        re_start_iter=2,
        return_pil=True,
        encoded_prompt=None,
        encoded_null_prompt=None,
    ) -> torch.Tensor:
        """
        Autoregressive inference only.

        :param prompt: text prompt or list of prompts to condition on
        :param null_prompt: prompt for the unconditional branch of CFG
        :param g_seed: random seed; if None, the sampler is not reseeded
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smooth the prediction with Gumbel softmax; only
            used for visualization, not for FID/IS benchmarking
        :return: reconstructed images as a (B, 3, H, W) tensor in [0, 1], or a
            list of PIL images if `return_pil` is True
        """
        assert not self.var.training
        var = self.var
        vae = self.vae
        vae_quant = self.vae.quantize
        if g_seed is None:
            rng = None
        else:
            var.rng.manual_seed(g_seed)
            rng = var.rng

        if encoded_prompt is not None:
            assert encoded_null_prompt is not None
            context, cond_vector, context_attn_bias = self.var.parse_batch(
                encoded_prompt,
                encoded_null_prompt,
            )
        else:
            context, cond_vector, context_attn_bias = self.encode_prompt(
                prompt, null_prompt
            )
        B = context.shape[0] // 2  # context holds [cond; uncond], so 2B rows

        cond_vector = var.text_pooler(cond_vector)
        sos = cond_BD = cond_vector

        lvl_pos = var.lvl_embed(var.lvl_1L)
        if not var.rope:
            lvl_pos += var.pos_1LC
        next_token_map = (
            sos.unsqueeze(1)
            + var.pos_start.expand(2 * B, var.first_l, -1)
            + lvl_pos[:, : var.first_l]
        )

        cur_L = 0
        f_hat = sos.new_zeros(B, var.Cvae, var.patch_nums[-1], var.patch_nums[-1])

        for b in var.blocks:
            b.attn.kv_caching(True)
            b.cross_attn.kv_caching(True)

        for si, pn in enumerate(var.patch_nums):  # si: i-th segment
            ratio = si / var.num_stages_minus_1

            cond_BD_or_gss = var.shared_ada_lin(cond_BD)
            x_BLC = next_token_map

            if var.rope:
                freqs_cis = var.freqs_cis[:, cur_L : cur_L + pn * pn]
            else:
                freqs_cis = var.freqs_cis

            for block in var.blocks:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=None,
                    context=context,
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            cur_L += pn * pn

            logits_BlV = var.get_logits(x_BLC, cond_BD)
            # Ramp the guidance strength linearly over stages, then mix the
            # conditional and unconditional halves of the batch.
            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(
                logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
            )[:, :, 0]

            if re and si >= re_start_iter:
                # Resampling: redraw the tokens up to `re_max_depth` times and
                # keep, per sample, the draw with the highest total logit.
                selected_logits = torch.gather(
                    logits_BlV, -1, idx_Bl.unsqueeze(-1)
                )[:, :, 0]
                mx = selected_logits.sum(dim=-1)[:, None]
                for _ in range(re_max_depth):
                    new_idx_Bl = sample_with_top_k_top_p_(
                        logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
                    )[:, :, 0]
                    selected_logits = torch.gather(
                        logits_BlV, -1, new_idx_Bl.unsqueeze(-1)
                    )[:, :, 0]
                    new_mx = selected_logits.sum(dim=-1)[:, None]
                    idx_Bl = idx_Bl * (mx >= new_mx) + new_idx_Bl * (mx < new_mx)
                    mx = mx * (mx >= new_mx) + new_mx * (mx < new_mx)

            if not more_smooth:  # this is the default case
                h_BChw = vae_quant.embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to MaskGIT
                h_BChw = gumbel_softmax_with_rng(
                    logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng
                ) @ vae_quant.embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, var.Cvae, pn, pn)
            f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
                si, len(var.patch_nums), f_hat, h_BChw
            )

            if si != var.num_stages_minus_1:  # prepare for next stage
                next_token_map = next_token_map.view(B, var.Cvae, -1).transpose(1, 2)
                next_token_map = (
                    var.word_embed(next_token_map)
                    + lvl_pos[:, cur_L : cur_L + var.patch_nums[si + 1] ** 2]
                )
                # Double the batch size again for the CFG split.
                next_token_map = next_token_map.repeat(2, 1, 1)

        for b in var.blocks:
            b.attn.kv_caching(False)
            b.cross_attn.kv_caching(False)

        # De-normalize from [-1, 1] to [0, 1].
        img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
        if return_pil:
            img = self.to_image(img)
        return img
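

# Minimal usage sketch. The checkpoint id "michellemoorre/var-test" below is a
# placeholder assumption; any TVARHF checkpoint accepted by
# `TVARPipeline.from_pretrained` would be used the same way.
if __name__ == "__main__":
    pipe = TVARPipeline.from_pretrained("michellemoorre/var-test", device="cuda")
    images = pipe(
        prompt="a photo of an astronaut riding a horse",
        g_seed=42,        # reseed the sampler for reproducibility
        cfg=4.0,          # classifier-free guidance strength
        top_k=450,        # top-k truncation for sampling
        top_p=0.95,       # nucleus sampling threshold
        return_pil=True,  # convert the output tensor to PIL images
    )
    images[0].save("sample.png")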