Spaces:
Sleeping
Sleeping
from typing import Tuple | |
import torch.nn as nn | |
from .clip import FrozenCLIPEmbedder | |
from .quant import VectorQuantizer2 | |
from .var import VAR | |
from .vqvae import VQVAE | |
from .pipeline import TVARPipeline | |
def build_vae_var(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # VAR args
    depth=16,
    shared_aln=False,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    text_encoder_path=None,
    text_encoder_2_path=None,
    rope=False,
    rope_theta=100,
    rope_size=None,
    dpr=0,
    use_swiglu_ffn=False,
) -> Tuple[VQVAE, VAR, TVARPipeline]:
    """Build the VQVAE, the VAR transformer and a TVARPipeline wrapping both.

    Side effect: ``reset_parameters`` is replaced with a no-op on several
    ``torch.nn`` layer classes (see the loop below). The patch is applied to
    the classes themselves and is not restored by this function, so layers
    of those types constructed afterwards also skip their built-in init.

    Args:
        device: Target passed to ``.to(device)`` for both models and handed
            to ``TVARPipeline``.
        patch_nums: Forwarded to ``VQVAE`` as ``v_patch_nums`` and to ``VAR``
            as ``patch_nums``.
        V: Forwarded to ``VQVAE`` as ``vocab_size``.
        Cvae: Forwarded to ``VQVAE`` as ``z_channels``.
        ch: Forwarded to ``VQVAE`` as ``ch``.
        share_quant_resi: Forwarded to ``VQVAE``.
        depth: Forwarded to ``VAR`` as ``depth``; also determines
            ``num_heads`` (``depth``) and ``embed_dim`` (``depth * 64``).
        shared_aln, attn_l2_norm, rope, rope_theta, rope_size,
        use_swiglu_ffn: Forwarded unchanged to ``VAR``.
        init_adaln, init_adaln_gamma, init_head, init_std: Forwarded
            unchanged to ``VAR.init_weights``.
        text_encoder_path: Passed to ``FrozenCLIPEmbedder`` for the first
            text encoder. NOTE(review): the default ``None`` is passed
            through as-is; confirm ``FrozenCLIPEmbedder`` accepts it.
        text_encoder_2_path: Same, for the second text encoder.
        dpr: Base drop-path rate. When positive it is rescaled by
            ``depth / 24`` before being passed as ``drop_path_rate``.

    Returns:
        A 3-tuple ``(vae_local, var_wo_ddp, pipe)``: the VQVAE, the VAR model
        and the ``TVARPipeline`` built from them and the two text encoders.
    """
    heads = depth
    width = depth * 64
    if dpr > 0:
        dpr = dpr * depth / 24

    # disable built-in initialization for speed
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)

    var_wo_ddp = VAR(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=dpr,
        norm_eps=1e-6,
        shared_aln=shared_aln,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
    ).to(device)
    # Built-in init was disabled above, so weights are initialized explicitly.
    var_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )

    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = TVARPipeline(var_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, var_wo_ddp, pipe