import math

import torch
import torch.nn as nn
import torch.utils.checkpoint

from .utils.modules import PatchEmbed, TimestepEmbedder, PE_wrapper, RMSNorm
from .blocks import DiTBlock, JointDiTBlock, FinalBlock


class UDiT(nn.Module):
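    """U-shaped Diffusion Transformer (UDiT).

    The token sequence runs through `depth // 2` in-blocks, one mid-block and
    `depth // 2` out-blocks; out-blocks may receive long skip connections from
    the matching in-blocks (U-Net style). Inputs can be 1d or 2d, timestep
    conditioning is either a prepended token ('token') or an AdaLN variant
    ('ada', 'ada_single', 'ada_lora', 'ada_lora_bias'), and context can be
    fused by concatenation ('concat'), joint attention ('joint') or
    cross-attention ('cross').
    """
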
    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3,
                 input_type='2d', out_chans=None,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, qk_norm=None,
                 act_layer='gelu', norm_layer='layernorm',
                 context_norm=False,
                 use_checkpoint=False,
                 time_fusion='token',
                 ada_lora_rank=None, ada_lora_alpha=None,
                 cls_dim=None,
                 context_dim=768, context_fusion='concat',
                 context_max_length=128, context_pe_method='sinu',
                 pe_method='abs', rope_mode='none',
                 use_conv=True,
                 skip=True, skip_norm=True):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

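        # input patchification (note: for input_type='2d', img_size is indexed
        # as (H, W), so it should be a tuple/list rather than a single int)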
        self.in_chans = in_chans
        self.input_type = input_type
        if self.input_type == '2d':
            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        elif self.input_type == '1d':
            num_patches = img_size // patch_size
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
                                      embed_dim=embed_dim, input_type=input_type)
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans

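        # positional embedding for the patch tokens; rope_mode is passed on to
        # the attention blocks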
        self.rope = rope_mode
        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
                               length=num_patches)

        print(f'x position embedding: {pe_method}')
        print(f'rope mode: {self.rope}')

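        # diffusion timestep embedding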
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False

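        # optional class-label embedding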
        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),)
        else:
            self.cls_embed = None

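        # time fusion: 'token' prepends the time (and class) embedding as extra
        # tokens; the 'ada*' variants condition via AdaLN, with 'ada_single' /
        # 'ada_lora*' sharing a single 6*embed_dim projection across blocks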
        if time_fusion == 'token':
            self.extras = 2 if self.cls_embed else 1
            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
            self.use_adanorm = True
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError
        print(f'time fusion mode: {self.time_fusion}')

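        # context (e.g. text) embedding and fusion: 'concat' / 'joint' prepend
        # context tokens to the sequence, 'cross' uses cross-attention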
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),)
            self.context_fusion = context_fusion
            if context_fusion == 'concat' or context_fusion == 'joint':
                self.extras += context_max_length
                self.context_pe = PE_wrapper(dim=embed_dim,
                                             method=context_pe_method,
                                             length=context_max_length)
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(dim=embed_dim,
                                             method=context_pe_method,
                                             length=context_max_length)
                self.context_cross = True
                context_dim = embed_dim
            else:
                raise NotImplementedError
        print(f'context fusion mode: {context_fusion}')
        print(f'context position embedding: {context_pe_method}')

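        # 'joint' fusion uses JointDiTBlock and expects `skip` to be indexable
        # (skip[0] controls the long skip); otherwise the standard DiTBlock is used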
        if self.context_fusion == 'joint':
            Block = JointDiTBlock
            self.use_skip = skip[0]
        else:
            Block = DiTBlock
            self.use_skip = skip

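        # map the norm_layer string to the corresponding module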
        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError

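        # U-shaped stack: depth//2 in-blocks, one mid-block and depth//2
        # out-blocks; only the out-blocks receive long skip connections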
        print(f'use long skip connection: {skip}')
        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
                act_layer=act_layer, norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
                skip=False, skip_norm=False,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.mid_block = Block(
            dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
            act_layer=act_layer, norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
            skip=False, skip_norm=False,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint)

        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
                act_layer=act_layer, norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
                skip=skip, skip_norm=skip_norm,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

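        # final block projecting tokens back to the output channels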
        self.use_conv = use_conv
        self.final_block = FinalBlock(embed_dim=embed_dim,
                                      patch_size=patch_size,
                                      img_size=img_size,
                                      in_chans=out_chans,
                                      input_type=input_type,
                                      norm_layer=norm_layer,
                                      use_conv=use_conv,
                                      use_adanorm=self.use_adanorm)
        self.initialize_weights()

    def _init_ada(self):
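        """Zero-initialise the AdaLN projections (and the LoRA-B weights for
        the 'ada_lora*' variants) so time conditioning starts from zero."""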
        if self.time_fusion == 'ada':
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self.in_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
        elif self.time_fusion == 'ada_single':
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
        elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self.in_blocks:
                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
                                         a=math.sqrt(5))
                nn.init.constant_(block.adaln.lora_b.weight, 0)
            nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
                                     a=math.sqrt(5))
            nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
            for block in self.out_blocks:
                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
                                         a=math.sqrt(5))
                nn.init.constant_(block.adaln.lora_b.weight, 0)

    def initialize_weights(self):
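        """Xavier-initialise all Linear layers, then apply the special cases:
        patch embedding, AdaLN/LoRA projections, cross-attention output
        projections, the class-embedding head, and the final output layer."""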
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        if self.use_adanorm:
            self._init_ada()

        if self.context_cross:
            for block in self.in_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)

        if self.cls_embed:
            if self.use_adanorm:
                nn.init.constant_(self.cls_embed[-1].weight, 0)
                nn.init.constant_(self.cls_embed[-1].bias, 0)

        if self.use_conv:
            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
            nn.init.constant_(self.final_block.final_layer.bias, 0)

    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
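        """Prepend the context tokens to the patch tokens and build the joint
        attention mask (missing masks default to all-ones)."""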
        assert context.shape[-2] == self.context_max_length

        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(B, context.shape[-2],
                                      device=context.device).bool()

        x_mask = torch.cat([context_mask, x_mask], dim=1)
        x = torch.cat((context, x), dim=1)
        return x, x_mask

    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None,
                cls_token=None):
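        """Run the model on `x` at the given `timesteps`, conditioned on
        `context` and, if a class embedding is configured, `cls_token`.
        `x_mask` / `context_mask` are optional boolean masks."""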
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)

        x = self.patch_embed(x)
        x = self.x_pe(x)

        B, L, D = x.shape

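        # embed the context and, for 'concat' / 'joint' fusion, prepend it to x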
        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
                x, x_mask = self._concat_x_context(x=x, context=context_token,
                                                   x_mask=x_mask,
                                                   context_mask=context_mask)
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

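        # time (and optional class) conditioning: AdaLN vectors when
        # use_adanorm, otherwise extra tokens prepended to the sequence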
        time_token = self.time_embed(timesteps)
        if self.cls_embed:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
                     x_mask], dim=1)
            time_token = None

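        # U-shaped pass: in-blocks collect skips, out-blocks consume them in
        # reverse order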
        skips = []
        for blk in self.in_blocks:
            x = blk(x=x, time_token=time_token, time_ada=time_ada,
                    skip=None, context=context_token,
                    x_mask=x_mask, context_mask=context_mask,
                    extras=self.extras)
            if self.use_skip:
                skips.append(x)

        x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
                           skip=None, context=context_token,
                           x_mask=x_mask, context_mask=context_mask,
                           extras=self.extras)

        for blk in self.out_blocks:
            skip = skips.pop() if self.use_skip else None
            x = blk(x=x, time_token=time_token, time_ada=time_ada,
                    skip=skip, context=context_token,
                    x_mask=x_mask, context_mask=context_mask,
                    extras=self.extras)

        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
        return x
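
# Example usage (illustrative sketch only, not part of the original module;
# the tensor layouts assumed for the input and context below are assumptions):
#
#   model = UDiT(img_size=256, patch_size=4, in_chans=8, input_type='1d',
#                embed_dim=384, depth=8, num_heads=6,
#                context_dim=512, context_fusion='cross',
#                context_max_length=64, time_fusion='ada')
#   x = torch.randn(2, 8, 256)          # assumed (batch, in_chans, length)
#   t = torch.randint(0, 1000, (2,))    # one timestep per batch element
#   ctx = torch.randn(2, 64, 512)       # (batch, context_max_length, context_dim)
#   out = model(x, t, ctx)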