# File: CraftsMan/craftsman/models/denoisers/pixart_denoiser.py
# Page header residue: wyysf's picture -- "update to v1.5" (commit 8133633)
from dataclasses import dataclass
import torch
import torch.nn as nn
import math
import importlib
import craftsman
import re
from typing import Optional
from craftsman.utils.base import BaseModule
from craftsman.models.denoisers.utils import *
@craftsman.register("pixart-denoiser")
class PixArtDinoDenoiser(BaseModule):
    """Denoiser that stacks configurable DiT-style blocks over latent tokens.

    Latent tokens are linearly embedded, then passed through ``cfg.layers``
    blocks that each receive the visual condition and a timestep-derived
    modulation tensor, and finally through a ``T2IFinalLayer``.  The visual
    condition is a weighted sum of separately projected CLIP and DINO
    features taken from the two halves of the ``context`` input.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Checkpoint path handed to ``torch.load``; falsy means "do not load".
        pretrained_model_name_or_path: Optional[str] = None
        # Feature size of each input latent token (in_features of x_embed).
        input_channels: int = 32
        # Base output feature size; doubled when the variance is learned.
        output_channels: int = 32
        # Not referenced by the code in this class.
        n_ctx: int = 512
        # Hidden width shared by the embeddings and the blocks.
        width: int = 768
        # Number of blocks in the stack.
        layers: int = 28
        # Passed to every block as ``heads``.
        heads: int = 16
        # Feature size of the incoming CLIP / DINO condition tokens.
        context_dim: int = 1024
        # Number of views used to reshape ``context`` in ``forward``.
        n_views: int = 1
        # If True, a LayerNorm precedes each condition projection.
        context_ln: bool = True
        # Not referenced by the code in this class.
        skip_ln: bool = False
        # Scaled by sqrt(1 / width) and passed to every block.
        init_scale: float = 0.25
        # Not referenced by the code in this class.
        use_checkpoint: bool = False
        # End value of the per-block drop-path schedule (linspace from 0).
        # It is also what ``configure`` passes to the blocks as ``qkv_bias``.
        drop_path: float = 0.
        # "LEARNED" / "LEARNED_RANGE" (case-insensitive) double the output size.
        variance_type: str = ""
        # Not referenced by the code in this class.
        img_pos_embed: bool = False
        # Weights of the CLIP and DINO terms in the summed visual condition.
        clip_weight: float = 1.0
        dino_weight: float = 1.0
        # Name of the block class looked up in craftsman.models.denoisers.utils.
        dit_block: str = ""

    cfg: Config

    def configure(self) -> None:
        """Build all sub-modules and optionally load pretrained weights.

        Raises:
            AttributeError: if ``cfg.dit_block`` does not name an attribute of
                ``craftsman.models.denoisers.utils``.
        """
        super().configure()

        # timestep embedding
        self.time_embed = TimestepEmbedder(self.cfg.width)

        # x embedding
        self.x_embed = nn.Linear(self.cfg.input_channels, self.cfg.width, bias=True)

        # context embedding
        if self.cfg.context_ln:
            self.clip_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
            self.dino_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
        else:
            self.clip_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
            self.dino_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)

        init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
        # One drop-path rate per block, rising linearly from 0 to cfg.drop_path.
        drop_path = [
            rate.item()
            for rate in torch.linspace(0, self.cfg.drop_path, self.cfg.layers)
        ]

        # Resolve the block class by name; fail with a message that names the
        # offending value instead of the bare "module has no attribute ''".
        utils_module = importlib.import_module("craftsman.models.denoisers.utils")
        ditblock = getattr(utils_module, self.cfg.dit_block, None)
        if ditblock is None:
            raise AttributeError(
                f"cfg.dit_block={self.cfg.dit_block!r} is not defined in "
                "craftsman.models.denoisers.utils"
            )

        self.blocks = nn.ModuleList([
            ditblock(
                width=self.cfg.width,
                heads=self.cfg.heads,
                init_scale=init_scale,
                # NOTE(review): qkv_bias receives the drop-path *rate*, so the
                # bias is off at the default 0. and on for any non-zero rate.
                # This looks unintended, but changing it would alter the
                # parameter set of existing checkpoints -- confirm before fixing.
                qkv_bias=self.cfg.drop_path,
                use_flash=True,
                drop_path=drop_path[i],
            )
            for i in range(self.cfg.layers)
        ])

        # Projects the timestep embedding to 6 * width features, which are
        # handed to every block in ``forward``.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.cfg.width, 6 * self.cfg.width, bias=True),
        )

        # final layer
        if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]:
            # Twice the channels: forward_with_dpmsolver keeps only the first
            # half along the last dimension.
            self.output_channels = self.cfg.output_channels * 2
        else:
            self.output_channels = self.cfg.output_channels
        self.final_layer = T2IFinalLayer(self.cfg.width, self.output_channels)

        self.identity_initialize()

        if self.cfg.pretrained_model_name_or_path:
            print(f"Loading pretrained model from {self.cfg.pretrained_model_name_or_path}")
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict']
            # Keep only the denoiser weights and strip exactly the leading
            # prefix.  (str.replace would also remove any later occurrence of
            # the same substring inside a key.)
            prefix = 'denoiser_model.'
            self.denoiser_ckpt = {
                k[len(prefix):]: v
                for k, v in ckpt.items()
                if k.startswith(prefix)
            }
            # strict=False: missing or unexpected keys are ignored silently.
            self.load_state_dict(self.denoiser_ckpt, strict=False)

    def forward_with_dpmsolver(self, model_input, timestep, context):
        """
        Forward pass for DPM-Solver, which does not need the variance prediction.

        Takes the same arguments as ``forward``.  When ``cfg.variance_type`` is
        "LEARNED" or "LEARNED_RANGE" the output is split in two along the last
        dimension and only the first half is returned; otherwise the output of
        ``forward`` is returned unchanged.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        model_out = self.forward(model_input, timestep, context)
        if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]:
            return model_out.chunk(2, dim=-1)[0]
        else:
            return model_out

    def identity_initialize(self):
        """Zero the output projections of every block.

        Sets weight and bias of ``attn.c_proj``, ``cross_attn.c_proj`` and
        ``mlp.c_proj`` to 0 in each block -- presumably so that every branch
        contributes nothing at initialisation (verify against the block
        implementation).
        """
        for block in self.blocks:
            nn.init.constant_(block.attn.c_proj.weight, 0)
            nn.init.constant_(block.attn.c_proj.bias, 0)
            nn.init.constant_(block.cross_attn.c_proj.weight, 0)
            nn.init.constant_(block.cross_attn.c_proj.bias, 0)
            nn.init.constant_(block.mlp.c_proj.weight, 0)
            nn.init.constant_(block.mlp.c_proj.bias, 0)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]; it is viewed
                as [bs, n_views, tokens, context_dim] and split in two along
                the token axis, the first half going through ``clip_embed``
                and the second through ``dino_embed``.

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """
        # Unpacking also enforces that model_input is 3-D.
        B, _, _ = model_input.shape

        # 1. time
        t_emb = self.time_embed(timestep)

        # 2. conditions projector
        context = context.view(B, self.cfg.n_views, -1, self.cfg.context_dim)
        # The two halves are summed below, so they are expected to hold the
        # same number of tokens.
        clip_feat, dino_feat = context.chunk(2, dim=2)
        clip_cond = self.clip_embed(clip_feat.contiguous().view(B, -1, self.cfg.context_dim))
        dino_cond = self.dino_embed(dino_feat.contiguous().view(B, -1, self.cfg.context_dim))
        visual_cond = self.cfg.clip_weight * clip_cond + self.cfg.dino_weight * dino_cond

        # 3. denoiser
        latent = self.x_embed(model_input)

        # Insert an axis at dim 1 on the timestep modulation -- presumably so
        # it broadcasts over the token axis inside the blocks.
        t0 = self.t_block(t_emb).unsqueeze(dim=1)
        for block in self.blocks:
            latent = auto_grad_checkpoint(block, latent, visual_cond, t0)

        latent = self.final_layer(latent, t_emb)

        return latent