from dataclasses import dataclass

import torch
import torch.nn as nn
import math
import importlib
import craftsman
import re
from typing import Optional

from craftsman.utils.base import BaseModule
from craftsman.models.denoisers.utils import *

class PixArtDinoDenoiser(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: Optional[str] = None
        input_channels: int = 32
        output_channels: int = 32
        n_ctx: int = 512
        width: int = 768
        layers: int = 28
        heads: int = 16
        context_dim: int = 1024
        n_views: int = 1
        context_ln: bool = True
        skip_ln: bool = False
        init_scale: float = 0.25
        use_checkpoint: bool = False
        drop_path: float = 0.
        variance_type: str = ""
        img_pos_embed: bool = False
        clip_weight: float = 1.0
        dino_weight: float = 1.0
        dit_block: str = ""

    cfg: Config
    def configure(self) -> None:
        super().configure()

        # timestep embedding
        self.time_embed = TimestepEmbedder(self.cfg.width)

        # x embedding
        self.x_embed = nn.Linear(self.cfg.input_channels, self.cfg.width, bias=True)

        # context embedding
        if self.cfg.context_ln:
            self.clip_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
            self.dino_embed = nn.Sequential(
                nn.LayerNorm(self.cfg.context_dim),
                nn.Linear(self.cfg.context_dim, self.cfg.width),
            )
        else:
            self.clip_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
            self.dino_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)

        init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
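        # stochastic depth: drop-path rates increase linearly from 0 to cfg.drop_path across the blocks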
        drop_path = [x.item() for x in torch.linspace(0, self.cfg.drop_path, self.cfg.layers)]
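        # resolve the DiT block class by name; cfg.dit_block must name a class defined in craftsman.models.denoisers.utils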
        ditblock = getattr(importlib.import_module("craftsman.models.denoisers.utils"), self.cfg.dit_block)
        self.blocks = nn.ModuleList([
            ditblock(
                width=self.cfg.width,
                heads=self.cfg.heads,
                init_scale=init_scale,
                qkv_bias=self.cfg.drop_path,
                use_flash=True,
                drop_path=drop_path[i],
            )
            for i in range(self.cfg.layers)
        ])
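        # adaLN-single (PixArt-style): map the timestep embedding to 6 * width modulation
        # parameters (shift/scale/gate for the attention and MLP branches), shared by all blocks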
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.cfg.width, 6 * self.cfg.width, bias=True)
        )

        # final layer
        if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]:
            self.output_channels = self.cfg.output_channels * 2
        else:
            self.output_channels = self.cfg.output_channels
        self.final_layer = T2IFinalLayer(self.cfg.width, self.output_channels)

        self.identity_initialize()
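        # optionally warm-start from a checkpoint, keeping only the weights stored under the `denoiser_model.` prefix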
        if self.cfg.pretrained_model_name_or_path:
            print(f"Loading pretrained model from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict']
            self.denoiser_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith('denoiser_model.'):
                    self.denoiser_ckpt[k.replace('denoiser_model.', '')] = v
            self.load_state_dict(self.denoiser_ckpt, strict=False)
    def forward_with_dpmsolver(self, model_input, timestep, context):
        """
        DPM-Solver does not need the variance prediction, so only the mean channels are returned.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        model_out = self.forward(model_input, timestep, context)
        if self.cfg.variance_type.upper() in ["LEARNED", "LEARNED_RANGE"]:
            return model_out.chunk(2, dim=-1)[0]
        else:
            return model_out
    def identity_initialize(self):
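        # zero the output projections of self-attention, cross-attention and MLP so each
        # block initially reduces to an identity mapping through its residual connections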
        for block in self.blocks:
            nn.init.constant_(block.attn.c_proj.weight, 0)
            nn.init.constant_(block.attn.c_proj.bias, 0)
            nn.init.constant_(block.cross_attn.c_proj.weight, 0)
            nn.init.constant_(block.cross_attn.c_proj.bias, 0)
            nn.init.constant_(block.mlp.c_proj.weight, 0)
            nn.init.constant_(block.mlp.c_proj.bias, 0)
    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, c]
        """
        B, n_data, _ = model_input.shape

        # 1. time
        t_emb = self.time_embed(timestep)

        # 2. conditions projector
        context = context.view(B, self.cfg.n_views, -1, self.cfg.context_dim)
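        # the token axis packs the CLIP features followed by the DINO features for each view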
        clip_feat, dino_feat = context.chunk(2, dim=2)
        clip_cond = self.clip_embed(clip_feat.contiguous().view(B, -1, self.cfg.context_dim))
        dino_cond = self.dino_embed(dino_feat.contiguous().view(B, -1, self.cfg.context_dim))
        visual_cond = self.cfg.clip_weight * clip_cond + self.cfg.dino_weight * dino_cond

        # 3. denoiser
        latent = self.x_embed(model_input)
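        # timestep modulation parameters shared by all blocks, with a token dimension added for broadcasting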
        t0 = self.t_block(t_emb).unsqueeze(dim=1)
        for block in self.blocks:
            latent = auto_grad_checkpoint(block, latent, visual_cond, t0)

        latent = self.final_layer(latent, t_emb)

        return latent
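

# A minimal, torch-only sketch of the context layout that forward() expects. This is an
# illustrative assumption inferred from the view()/chunk(2, dim=2) calls above, not code from
# the original repository: CLIP and DINO features are assumed to have the same token count
# and to be concatenated along the token axis for each view before being flattened.
if __name__ == "__main__":
    B, n_views, n_tokens, context_dim = 2, 1, 257, 1024
    clip_feat = torch.randn(B, n_views, n_tokens, context_dim)   # e.g. CLIP patch tokens
    dino_feat = torch.randn(B, n_views, n_tokens, context_dim)   # e.g. DINO patch tokens

    # pack per view, then flatten to the [bs, context_tokens, c] shape forward() takes
    context = torch.cat([clip_feat, dino_feat], dim=2).view(B, -1, context_dim)

    # forward() undoes the packing with view(...) followed by chunk(2, dim=2)
    clip_back, dino_back = context.view(B, n_views, -1, context_dim).chunk(2, dim=2)
    assert torch.equal(clip_back, clip_feat) and torch.equal(dino_back, dino_feat)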