import math
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfCrossAttn
from models.clip import FrozenCLIPEmbedder
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.rope import compute_axial_cis
from models.vqvae import VQVAE, VectorQuantizer2


class SharedAdaLin(nn.Linear):
    def forward(self, cond_BD):
        C = self.weight.shape[0] // 6
        return super().forward(cond_BD).view(-1, 1, 6, C)  # B16C


class VAR(nn.Module):
    def __init__(
        self,
        rope=False,
        rope_theta=100,
        rope_size=None,
        depth=16,
        embed_dim=1024,
        num_heads=16,
        mlp_ratio=4.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_eps=1e-6,
        shared_aln=False,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        fused_if_available=True,
        use_swiglu_ffn=False,
        Cvae=32,
        V=4096,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.depth, self.C, self.D, self.num_heads = (
            depth,
            embed_dim,
            embed_dim,
            num_heads,
        )
        self.Cvae, self.V = Cvae, V
        self.prog_si = -1  # progressive training

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn**2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        self.rope = rope
        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device=dist.get_device())

        # 1. input (word) embedding
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. text embedding
        self.pooled_embed_size = 1280
        context_dim = 1280 + 768
        self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)

        init_std = math.sqrt(1 / self.C / 3)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. position embedding
        if not self.rope:
            # absolute position embedding
            pos_1LC = []
            for i, pn in enumerate(self.patch_nums):
                pe = torch.empty(1, pn * pn, self.C)
                nn.init.trunc_normal_(pe, mean=0, std=init_std)
                pos_1LC.append(pe)
            pos_1LC = torch.cat(pos_1LC, dim=1)  # 1, L, C
            assert tuple(pos_1LC.shape) == (1, self.L, self.C)
            self.pos_1LC = nn.Parameter(pos_1LC)
            self.freqs_cis = None
        else:
            # RoPE position embedding
            assert (
                self.C // self.num_heads
            ) % 4 == 0, "2d rope needs head dim to be divisible by 4"
            patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
            self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
            freqs_cis = []
            for i, pn in enumerate(self.patch_nums):
                norm_coeff = rope_size / patch_nums_m1[i]
                cur_freqs = self.compute_cis(
                    end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
                )
                freqs_cis.append(cur_freqs[None, ...])
            self.freqs_cis = torch.cat(freqs_cis, dim=1)  # 1, L, C // 2 -- complex

        # level embedding (similar to GPT's segment embedding,
        # used to distinguish different levels of the token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
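        # Worked example of the scale bookkeeping above (illustrative comment only):
        # with the default patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        #   self.L       = 1 + 4 + 9 + 16 + 25 + 36 + 64 + 100 + 169 + 256 = 680
        #   self.first_l = 1  (the single start token of the 1x1 scale)
        # so the teacher-forcing input x_BLCv_wo_first_l carries L - first_l = 679
        # token embeddings, while pos_1LC (or the RoPE frequencies) covers all 680
        # positions of the flattened token pyramid.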
        # 4. backbone blocks
        self.shared_ada_lin = (
            nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6 * self.C))
            if shared_aln
            else nn.Identity()
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        # stochastic depth decay rule (linearly increasing)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([])
        for block_idx in range(depth):
            self.blocks.append(
                AdaLNSelfCrossAttn(
                    cond_dim=self.D,
                    shared_aln=shared_aln,
                    block_idx=block_idx,
                    embed_dim=self.C,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[block_idx],
                    last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                    qk_norm=attn_l2_norm,
                    context_dim=context_dim,
                    use_swiglu_ffn=use_swiglu_ffn,
                    norm_eps=norm_eps,
                )
            )

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f"\n[constructor] ==== fused_if_available={fused_if_available} "
            f"(fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, "
            f"fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
            f"    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n"
            f"    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, "
            f"drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
            end="\n\n",
            flush=True,
        )

        # 5. attention mask used in training (for masking out the future)
        # it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat(
            [torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
        ).view(1, self.L, 1)
        dT = d.transpose(1, 2)  # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer("lvl_1L", lvl_1L)
        attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf).reshape(
            1, 1, self.L, self.L
        )
        self.register_buffer(
            "attn_bias_for_masking", attn_bias_for_masking.contiguous()
        )
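        # Illustration of the mask above (comment only): for patch_nums = (1, 2),
        # d = [0, 1, 1, 1, 1], so a query may attend to every key whose scale index
        # is <= its own. The resulting 5x5 bias is
        #     [[  0, -inf, -inf, -inf, -inf],
        #      [  0,    0,    0,    0,    0],
        #      [  0,    0,    0,    0,    0],
        #      [  0,    0,    0,    0,    0],
        #      [  0,    0,    0,    0,    0]]
        # i.e. block-causal across scales, fully bidirectional within a scale.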
        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

        # By default, gradient checkpointing is disabled
        self.use_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False

    def get_logits(
        self,
        h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        cond_BD: Optional[torch.Tensor],
    ):
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual  # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:
            # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    def parse_batch(self, batch, null_batch=None):
        embedding_1 = batch["vit_l_14_text_embeddings"]
        embedding_2 = batch["vit_bigg_14_text_embeddings"]
        attention_mask = batch["vit_bigg_14_text_mask"]
        batch_size = embedding_1.size(0)

        prompt_embed = torch.concat([embedding_1, embedding_2], dim=-1)
        prompt_lens = attention_mask.sum(dim=-1).to(int)
        pooled_output = embedding_2[
            torch.arange(batch_size, device=embedding_2.device), prompt_lens - 1
        ]

        attention_bias = attention_mask.clone()
        attention_bias[attention_mask == 0] = -float("inf")
        attention_bias[attention_mask == 1] = 0.0

        if null_batch is not None:
            B, L, hidden_dim = prompt_embed.shape
            pooled_dim = pooled_output.shape[1]

            null_context = null_batch["prompt_embed"]
            null_pooled_embed = null_batch["pooled_embed"]
            null_attn_bias = null_batch["attn_bias"]

            null_context = (
                null_context[:, :L].expand(B, L, hidden_dim).to(prompt_embed.device)
            )
            null_pooled_embed = (
                null_pooled_embed.expand(B, pooled_dim).to(pooled_output.device)
            )
            null_attn_bias = (
                null_attn_bias[:, :L].expand(B, L).to(attention_bias.device)
            )

            prompt_embed = torch.cat([prompt_embed, null_context], dim=0)
            pooled_output = torch.cat([pooled_output, null_pooled_embed], dim=0)
            attention_bias = torch.cat([attention_bias, null_attn_bias], dim=0)

        return (
            prompt_embed.to(dist.get_device()),
            pooled_output.to(dist.get_device()),
            attention_bias.to(dist.get_device()),
        )

    def forward(
        self,
        x_BLCv_wo_first_l: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        prompt_attn_bias: torch.Tensor,
    ) -> torch.Tensor:  # returns logits_BLV
        """
        :param x_BLCv_wo_first_l: teacher-forcing input (B, self.L - self.first_l, self.Cvae)
        :param prompt_embeds: concatenated text embeddings from CLIP-ViT-L-14 and CLIP-ViT-Big-G-14
        :param pooled_prompt_embeds: pooled CLIP-ViT-Big-G-14 embedding (taken at the last text token)
        :param prompt_attn_bias: additive attention bias built from the ViT-Big-G-14 text mask
        :return: logits BLV, V is vocab_size
        """
        bg, ed = 0, self.L
        B = x_BLCv_wo_first_l.shape[0]
        with torch.amp.autocast('cuda', enabled=False):
            pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)
            sos = cond_BD = pooled_prompt_embeds
            sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
                B, self.first_l, -1
            )

            x_BLC = torch.cat(
                (sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
            )
            x_BLC += self.lvl_embed(
                self.lvl_1L[:, :ed].expand(B, -1)
            )  # lvl: BLC; pos: 1LC
            if not self.rope:
                x_BLC += self.pos_1LC[:, :ed]

        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
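        # At this point (comment only): x_BLC is (B, L, C) -- the start token(s) of the
        # 1x1 scale followed by the word-embedded teacher-forcing tokens -- and attn_bias
        # is the (1, 1, L, L) block-causal mask registered in __init__, so each scale can
        # attend to itself and to all coarser scales but never to finer (future) ones.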
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        for block in self.blocks:
            if self.use_gradient_checkpointing:
                x_BLC = torch.utils.checkpoint.checkpoint(
                    block,
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=attn_bias,
                    context=prompt_embeds,
                    freqs_cis=self.freqs_cis,
                    context_attn_bias=prompt_attn_bias,
                    use_reentrant=False,
                )
            else:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=attn_bias,
                    context=prompt_embeds,
                    freqs_cis=self.freqs_cis,
                    context_attn_bias=prompt_attn_bias,
                )

        with torch.amp.autocast('cuda', enabled=not self.training):
            x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        return x_BLC  # logits BLV, V is vocab_size

    def init_weights(
        self,
        init_adaln=0.5,
        init_adaln_gamma=1e-5,
        init_head=0.02,
        init_std=0.02,
    ):
        if init_std < 0:
            init_std = (1 / self.C / 3) ** 0.5  # init_std < 0: automated
        print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
        for m in self.modules():
            with_weight = hasattr(m, "weight") and m.weight is not None
            with_bias = hasattr(m, "bias") and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
            elif isinstance(
                m,
                (
                    nn.LayerNorm,
                    nn.BatchNorm1d,
                    nn.BatchNorm2d,
                    nn.BatchNorm3d,
                    nn.SyncBatchNorm,
                    nn.GroupNorm,
                    nn.InstanceNorm1d,
                    nn.InstanceNorm2d,
                    nn.InstanceNorm3d,
                ),
            ):
                if with_weight:
                    m.weight.data.fill_(1.0)
                if with_bias:
                    m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if (
                hasattr(self.head_nm.ada_lin[-1], "bias")
                and self.head_nm.ada_lin[-1].bias is not None
            ):
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block in self.blocks:
            block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(block.ffn, "fc2"):
                block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))

            if hasattr(block, "ada_lin"):
                block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
                block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
                if (
                    hasattr(block.ada_lin[-1], "bias")
                    and block.ada_lin[-1].bias is not None
                ):
                    block.ada_lin[-1].bias.data.zero_()
            elif hasattr(block, "ada_gss"):
                block.ada_gss.data[:, :, 2:].mul_(init_adaln)
                block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f"drop_path_rate={self.drop_path_rate:g}"


class TVARHF(VAR, PyTorchModelHubMixin):  # tags=["image-generation"]
    def __init__(
        self,
        depth=30,
        shared_aln=False,
        attn_l2_norm=True,
        rope=True,
        rope_theta=10000,
        rope_size=128,
        use_swiglu_ffn=True,
    ):
        heads = depth
        width = depth * 64
        super().__init__(
            depth=depth,
            embed_dim=width,
            num_heads=heads,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            norm_eps=1e-6,
            shared_aln=shared_aln,
            attn_l2_norm=attn_l2_norm,
            patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
            rope=rope,
            rope_theta=rope_theta,
            rope_size=rope_size,
            use_swiglu_ffn=use_swiglu_ffn,
        )
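

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the training or
    # inference pipeline). It assumes the sibling modules (dist, models.basic_var,
    # models.rope, ...) are importable and that dist.get_device() resolves without a
    # full distributed setup; the tiny config below is chosen purely for shape checks.
    B, ctx_len = 2, 77
    model = VAR(depth=2, embed_dim=64, num_heads=4)  # default patch_nums -> L = 680

    x = torch.randn(B, model.L - model.first_l, model.Cvae)  # teacher-forcing tokens
    prompt = torch.randn(B, ctx_len, 768 + 1280)             # ViT-L/14 + ViT-bigG/14 embeddings
    pooled = torch.randn(B, model.pooled_embed_size)         # pooled ViT-bigG/14 embedding
    ctx_bias = torch.zeros(B, ctx_len)                       # 0 = attend, -inf = masked out

    logits = model(x, prompt, pooled, ctx_bias)
    print(logits.shape)  # expected: (B, model.L, model.V) == (2, 680, 4096)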