# models/__init__.py
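"""Model builders for the TVAR demo.

Exposes :func:`build_vae_var`, which assembles the VQVAE tokenizer, the VAR
transformer, two frozen CLIP text encoders, and a :class:`TVARPipeline`.
"""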
from typing import Tuple
import torch.nn as nn
from .clip import FrozenCLIPEmbedder
from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE
from .pipeline import TVARPipeline


def build_vae_var(
# Shared args
device,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
# VQVAE args
V=4096,
Cvae=32,
ch=160,
share_quant_resi=4,
# VAR args
depth=16,
shared_aln=False,
attn_l2_norm=True,
init_adaln=0.5,
init_adaln_gamma=1e-5,
init_head=0.02,
init_std=-1, # init_std < 0: automated
text_encoder_path=None,
text_encoder_2_path=None,
rope=False,
rope_theta=100,
rope_size=None,
dpr=0,
use_swiglu_ffn=False,
) -> Tuple[VQVAE, VAR, TVARPipeline]:
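    """Build the VQVAE tokenizer, the VAR transformer, and a TVAR pipeline.

    Model width and head count are derived from ``depth``
    (``width = depth * 64``, ``heads = depth``).
    Returns ``(vae_local, var_wo_ddp, pipe)``.
    """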
    # Width and head count scale linearly with transformer depth.
    heads = depth
    width = depth * 64
    if dpr > 0:
        # Rescale the drop-path rate proportionally to depth (normalized at depth 24).
        dpr = dpr * depth / 24
    # Disable PyTorch's built-in parameter initialization for speed. Note that
    # this monkey-patch is process-wide; VAR.init_weights below applies the
    # intended initialization instead.
for clz in (
nn.Linear,
nn.LayerNorm,
nn.BatchNorm2d,
nn.SyncBatchNorm,
nn.Conv1d,
nn.Conv2d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
):
setattr(clz, "reset_parameters", lambda self: None)
# build models
vae_local = VQVAE(
vocab_size=V,
z_channels=Cvae,
ch=ch,
test_mode=True,
share_quant_resi=share_quant_resi,
v_patch_nums=patch_nums,
).to(device)
var_wo_ddp = VAR(
depth=depth,
embed_dim=width,
num_heads=heads,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=dpr,
norm_eps=1e-6,
shared_aln=shared_aln,
attn_l2_norm=attn_l2_norm,
patch_nums=patch_nums,
rope=rope,
rope_theta=rope_theta,
rope_size=rope_size,
use_swiglu_ffn=use_swiglu_ffn,
).to(device)
var_wo_ddp.init_weights(
init_adaln=init_adaln,
init_adaln_gamma=init_adaln_gamma,
init_head=init_head,
init_std=init_std,
)
text_encoder = FrozenCLIPEmbedder(text_encoder_path)
text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
pipe = TVARPipeline(var_wo_ddp, vae_local, text_encoder, text_encoder_2, device)
return vae_local, var_wo_ddp, pipe
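
# Example usage (a minimal sketch; the CLIP checkpoint paths below are
# assumptions, not pinned by this module, so substitute whichever text-encoder
# checkpoints your setup provides):
#
#     import torch
#     from models import build_vae_var
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     vae, var, pipe = build_vae_var(
#         device=device,
#         depth=16,
#         text_encoder_path="openai/clip-vit-large-patch14",
#         text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
#     )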