import math
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfCrossAttn
from models.clip import FrozenCLIPEmbedder
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.rope import compute_axial_cis
from models.vqvae import VQVAE, VectorQuantizer2


class SharedAdaLin(nn.Linear):
    def forward(self, cond_BD):
        # the layer produces 6 * C features; reshape to (B, 1, 6, C) so the six shared
        # AdaLN modulation tensors can be broadcast over the sequence dimension
        C = self.weight.shape[0] // 6
        return super().forward(cond_BD).view(-1, 1, 6, C)  # (B, 1, 6, C)
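
# Illustrative shape sketch (values chosen for illustration, not part of the model code):
# a pooled condition of shape (B, D) passed through SharedAdaLin(D, 6 * C) comes out as
# (B, 1, 6, C), one slice per shared AdaLN modulation tensor:
#
#   shared = SharedAdaLin(1024, 6 * 1024)      # assumes D == C == 1024
#   out = shared(torch.randn(2, 1024))
#   assert out.shape == (2, 1, 6, 1024)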


class VAR(nn.Module):
def __init__(
self,
rope=False,
rope_theta=100,
rope_size=None,
depth=16,
embed_dim=1024,
num_heads=16,
mlp_ratio=4.0,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_eps=1e-6,
shared_aln=False,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
fused_if_available=True,
use_swiglu_ffn=False,
Cvae=32,
V=4096
):
super().__init__()
# 0. hyperparameters
assert embed_dim % num_heads == 0
self.depth, self.C, self.D, self.num_heads = (
depth,
embed_dim,
embed_dim,
num_heads,
)
self.Cvae, self.V = Cvae, V
self.prog_si = -1 # progressive training
        self.patch_nums: Tuple[int, ...] = patch_nums
self.L = sum(pn**2 for pn in self.patch_nums)
self.first_l = self.patch_nums[0] ** 2
self.rope = rope
self.num_stages_minus_1 = len(self.patch_nums) - 1
self.rng = torch.Generator(device=dist.get_device())
# 1. input (word) embedding
self.word_embed = nn.Linear(self.Cvae, self.C)
# 2. text embedding
        self.pooled_embed_size = 1280  # pooled text embedding size (CLIP ViT-bigG/14)
        context_dim = 1280 + 768  # per-token context: ViT-bigG/14 (1280) + ViT-L/14 (768)
self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)
init_std = math.sqrt(1 / self.C / 3)
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
# 3. position embedding
if not self.rope:
# absolute position embedding
pos_1LC = []
for i, pn in enumerate(self.patch_nums):
pe = torch.empty(1, pn * pn, self.C)
nn.init.trunc_normal_(pe, mean=0, std=init_std)
pos_1LC.append(pe)
pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
assert tuple(pos_1LC.shape) == (1, self.L, self.C)
self.pos_1LC = nn.Parameter(pos_1LC)
self.freqs_cis = None
else:
# RoPE position embedding
assert (
self.C // self.num_heads
) % 4 == 0, "2d rope needs head dim to be divisible by 4"
patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
freqs_cis = []
for i, pn in enumerate(self.patch_nums):
norm_coeff = rope_size / patch_nums_m1[i]
cur_freqs = self.compute_cis(
end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
)
freqs_cis.append(cur_freqs[None, ...])
            self.freqs_cis = torch.cat(freqs_cis, dim=1)  # (1, L, head_dim // 2), complex
# level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
# 4. backbone blocks
self.shared_ada_lin = (
nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6 * self.C))
if shared_aln
else nn.Identity()
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
self.drop_path_rate = drop_path_rate
# stochastic depth decay rule (linearly increasing)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList([])
for block_idx in range(depth):
self.blocks.append(
AdaLNSelfCrossAttn(
cond_dim=self.D,
shared_aln=shared_aln,
block_idx=block_idx,
embed_dim=self.C,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[block_idx],
last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
qk_norm=attn_l2_norm,
context_dim=context_dim,
use_swiglu_ffn=use_swiglu_ffn,
norm_eps=norm_eps,
)
)
fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
self.using_fused_add_norm_fn = any(fused_add_norm_fns)
print(
f"\n[constructor] ==== fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
f" [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n"
f" [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
end="\n\n",
flush=True,
)
# 5. attention mask used in training (for masking out the future)
# it won't be used in inference, since kv cache is enabled
d: torch.Tensor = torch.cat(
[torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
).view(1, self.L, 1)
dT = d.transpose(1, 2) # dT: 11L
lvl_1L = dT[:, 0].contiguous()
self.register_buffer("lvl_1L", lvl_1L)
attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf).reshape(
1, 1, self.L, self.L
)
self.register_buffer(
"attn_bias_for_masking", attn_bias_for_masking.contiguous()
)
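        # Illustrative example (not executed): with patch_nums = (1, 2), the level ids are
        # d = [0, 1, 1, 1, 1], and the mask lets position i attend to position j only when
        # level(j) <= level(i), i.e. a block-causal mask over the scale pyramid:
        #
        #   [[  0., -inf, -inf, -inf, -inf],
        #    [  0.,   0.,   0.,   0.,   0.],
        #    [  0.,   0.,   0.,   0.,   0.],
        #    [  0.,   0.,   0.,   0.,   0.],
        #    [  0.,   0.,   0.,   0.,   0.]]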
# 6. classifier head
self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
self.head = nn.Linear(self.C, self.V)
        # gradient checkpointing is disabled by default
        self.use_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False

    def get_logits(
self,
h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
cond_BD: Optional[torch.Tensor],
):
if not isinstance(h_or_h_and_residual, torch.Tensor):
h, resi = h_or_h_and_residual # fused_add_norm must be used
h = resi + self.blocks[-1].drop_path(h)
else: # fused_add_norm is not used
h = h_or_h_and_residual
return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    def parse_batch(self, batch, null_batch=None):
        """Assemble (prompt_embed, pooled_output, attention_bias) from a dataloader batch.

        If null_batch is provided, its null-prompt embeddings are concatenated along the
        batch dimension (e.g. for classifier-free guidance).
        """
embedding_1 = batch["vit_l_14_text_embeddings"]
embedding_2 = batch["vit_bigg_14_text_embeddings"]
attention_mask = batch["vit_bigg_14_text_mask"]
batch_size = embedding_1.size(0)
prompt_embed = torch.concat([embedding_1, embedding_2], dim=-1)
prompt_lens = attention_mask.sum(dim=-1).to(int)
pooled_output = embedding_2[
torch.arange(batch_size, device=embedding_2.device), prompt_lens - 1
]
attention_bias = attention_mask.clone()
attention_bias[attention_mask == 0] = -float("inf")
attention_bias[attention_mask == 1] = 0.0
if null_batch is not None:
B, L, hidden_dim = prompt_embed.shape
pooled_dim = pooled_output.shape[1]
null_context = null_batch['prompt_embed']
null_pooled_embed = null_batch['pooled_embed']
null_attn_bias = null_batch['attn_bias']
null_context = null_context[:, :L].expand(B, L, hidden_dim).to(prompt_embed.device)
null_pooled_embed = null_pooled_embed.expand(B, pooled_dim).to(pooled_output.device)
null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attention_bias.device)
prompt_embed = torch.cat([prompt_embed, null_context], dim=0)
pooled_output = torch.cat([pooled_output, null_pooled_embed], dim=0)
attention_bias = torch.cat([attention_bias, null_attn_bias], dim=0)
return (
prompt_embed.to(dist.get_device()),
pooled_output.to(dist.get_device()),
attention_bias.to(dist.get_device()),
)
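
    # Illustrative sketch of the expected batch layout (shapes assumed from the CLIP text
    # encoders used here, not enforced by this code):
    #
    #   batch = {
    #       "vit_l_14_text_embeddings":    (B, L_text, 768)  float tensor,
    #       "vit_bigg_14_text_embeddings": (B, L_text, 1280) float tensor,
    #       "vit_bigg_14_text_mask":       (B, L_text) mask, 1 for valid tokens,
    #   }
    #
    # parse_batch then returns prompt_embed (B, L_text, 2048), pooled_output (B, 1280)
    # taken at the last attended position, and attention_bias (B, L_text) with 0 / -inf.
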
def forward(
self,
x_BLCv_wo_first_l: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
prompt_attn_bias: torch.Tensor,
) -> torch.Tensor: # returns logits_BLV
"""
:param batch: {'image': not used in forward,
'text': image caption,
'vit_l_14_text_embeddings': text embedding from CLIP-ViT-L-14
'vit_bigg_14_text_embeddings': text embedding from CLIP-ViT-Big-G-14
'vit_bigg_14_text_mask': attention mask to get a correct pooled embedding
:param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
:return: logits BLV, V is vocab_size
"""
bg, ed = 0, self.L
B = x_BLCv_wo_first_l.shape[0]
with torch.amp.autocast('cuda', enabled=False):
pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)
sos = cond_BD = pooled_prompt_embeds
sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
B, self.first_l, -1
)
x_BLC = torch.cat(
(sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
)
x_BLC += self.lvl_embed(
self.lvl_1L[:, :ed].expand(B, -1)
) # lvl: BLC; pos: 1LC
if not self.rope:
x_BLC += self.pos_1LC[:, :ed]
attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
# hack: get the dtype if mixed precision is used
temp = x_BLC.new_ones(8, 8)
main_type = torch.matmul(temp, temp).dtype
x_BLC = x_BLC.to(dtype=main_type)
cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
attn_bias = attn_bias.to(dtype=main_type)
for block in self.blocks:
if self.use_gradient_checkpointing:
x_BLC = torch.utils.checkpoint.checkpoint(
block,
x=x_BLC,
cond_BD=cond_BD_or_gss,
attn_bias=attn_bias,
context=prompt_embeds,
freqs_cis=self.freqs_cis,
context_attn_bias=prompt_attn_bias,
use_reentrant=False,
)
else:
x_BLC = block(
x=x_BLC,
cond_BD=cond_BD_or_gss,
attn_bias=attn_bias,
context=prompt_embeds,
freqs_cis=self.freqs_cis,
context_attn_bias=prompt_attn_bias,
)
with torch.amp.autocast('cuda', enabled=not self.training):
x_BLC = self.get_logits(x_BLC.float(), cond_BD)
return x_BLC # logits BLV, V is vocab_size
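
    # Illustrative sketch (hypothetical shapes, not executed): with the default
    # patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16), L = sum(pn * pn) = 680 and
    # first_l = 1, so a teacher-forcing step looks like
    #
    #   x_BLCv_wo_first_l = vqvae_features            # (B, 679, Cvae)
    #   logits = model(x_BLCv_wo_first_l, prompt_embeds, pooled_prompt_embeds, prompt_attn_bias)
    #   assert logits.shape == (B, 680, V)            # one V-way distribution per position
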
def init_weights(
self,
init_adaln=0.5,
init_adaln_gamma=1e-5,
init_head=0.02,
init_std=0.02,
):
if init_std < 0:
init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
for m in self.modules():
with_weight = hasattr(m, "weight") and m.weight is not None
with_bias = hasattr(m, "bias") and m.bias is not None
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if with_bias:
m.bias.data.zero_()
elif isinstance(m, nn.Embedding):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
elif isinstance(
m,
(
nn.LayerNorm,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.SyncBatchNorm,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
),
):
if with_weight:
m.weight.data.fill_(1.0)
if with_bias:
m.bias.data.zero_()
if init_head >= 0:
if isinstance(self.head, nn.Linear):
self.head.weight.data.mul_(init_head)
self.head.bias.data.zero_()
elif isinstance(self.head, nn.Sequential):
self.head[-1].weight.data.mul_(init_head)
self.head[-1].bias.data.zero_()
if isinstance(self.head_nm, AdaLNBeforeHead):
self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
if (
hasattr(self.head_nm.ada_lin[-1], "bias")
and self.head_nm.ada_lin[-1].bias is not None
):
self.head_nm.ada_lin[-1].bias.data.zero_()
depth = len(self.blocks)
for block in self.blocks:
block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
if hasattr(block.ffn, "fc2"):
block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
if hasattr(block, "ada_lin"):
block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
if (
hasattr(block.ada_lin[-1], "bias")
and block.ada_lin[-1].bias is not None
):
block.ada_lin[-1].bias.data.zero_()
elif hasattr(block, "ada_gss"):
block.ada_gss.data[:, :, 2:].mul_(init_adaln)
block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f"drop_path_rate={self.drop_path_rate:g}"
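
# Illustrative construction sketch (hyperparameters chosen for illustration only):
#
#   model = VAR(depth=16, embed_dim=1024, num_heads=16, rope=True, rope_size=128)
#   model.init_weights()
#   model.enable_gradient_checkpointing()   # optional: trade compute for memory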


class TVARHF(VAR, PyTorchModelHubMixin):
    # tags=["image-generation"]
def __init__(
self,
depth=30,
shared_aln=False,
attn_l2_norm=True,
rope=True,
rope_theta=10000,
rope_size=128,
use_swiglu_ffn=True,
):
        heads = depth
        width = depth * 64  # width and head count scale with depth; per-head dim is fixed at 64
super().__init__(
depth=depth,
embed_dim=width,
num_heads=heads,
drop_rate=0.0,
attn_drop_rate=0.0,
norm_eps=1e-6,
shared_aln=shared_aln,
attn_l2_norm=attn_l2_norm,
patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
rope=rope,
rope_theta=rope_theta,
rope_size=rope_size,
use_swiglu_ffn=use_swiglu_ffn,
)
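
# Loading sketch via PyTorchModelHubMixin (the repo id below is a placeholder, not a real
# checkpoint):
#
#   model = TVARHF.from_pretrained("your-org/your-tvar-checkpoint")
#   model.eval()
#
# from_pretrained / save_pretrained / push_to_hub are provided by PyTorchModelHubMixin.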