|
import math
import random

from einops import rearrange
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from tqdm import trange

from functools import partial

from nsr.networks_stylegan2 import Generator as StyleGAN2Backbone
from nsr.volumetric_rendering.renderer import ImportanceRenderer, ImportanceRendererfg_bg
from nsr.volumetric_rendering.ray_sampler import RaySampler
from nsr.triplane import OSGDecoder, Triplane, Triplane_fg_bg_plane

from vit.vision_transformer import (
    TriplaneFusionBlockv4_nested,
    TriplaneFusionBlockv4_nested_init_from_dino,
    TriplaneFusionBlockv4_nested_init_from_dino_lite,
    TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
    VisionTransformer)

from .vision_transformer import Block, VisionTransformer
from .utils import trunc_normal_

from guided_diffusion import dist_util, logger

from pdb import set_trace as st

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from utils.torch_utils.components import (Conv3x3TriplaneTransformation,
                                          PixelShuffleUpsample,
                                          PixelUnshuffleUpsample,
                                          ResidualBlock, Upsample)
from utils.torch_utils.distributions.distributions import DiagonalGaussianDistribution
from nsr.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X

from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from nsr.common_blks import ResMlp
# The star import supplies the Mlp, TriplaneFusionBlock* and RodinConv*
# building blocks referenced throughout this file.
from .vision_transformer import *  # noqa: F401,F403

from dit.dit_models import get_2d_sincos_pos_embed
from torch import _assert
from itertools import repeat
import collections.abc
|
|
|
|
|
|
|
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse
|
|
|
|
|
to_1tuple = _ntuple(1) |
|
to_2tuple = _ntuple(2) |
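# to_2tuple broadcasts scalars and passes iterables through, e.g.
# to_2tuple(32) == (32, 32) and to_2tuple((16, 24)) == (16, 24).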
|
|
|
|
|
class PatchEmbedTriplane(nn.Module):
    """Grouped-convolution patch embedder operating on a stacked triplane.

    The three planes are kept separate via groups=3 in the projection conv.
    """
|
|
|
def __init__( |
|
self, |
|
img_size=32, |
|
patch_size=2, |
|
in_chans=4, |
|
embed_dim=768, |
|
norm_layer=None, |
|
flatten=True, |
|
bias=True, |
|
): |
|
super().__init__() |
|
img_size = to_2tuple(img_size) |
|
patch_size = to_2tuple(patch_size) |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.grid_size = (img_size[0] // patch_size[0], |
|
img_size[1] // patch_size[1]) |
|
self.num_patches = self.grid_size[0] * self.grid_size[1] |
|
self.flatten = flatten |
|
|
|
self.proj = nn.Conv2d(in_chans, |
|
embed_dim * 3, |
|
kernel_size=patch_size, |
|
stride=patch_size, |
|
bias=bias, |
|
groups=3) |
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
|
def forward(self, x): |
|
B, C, H, W = x.shape |
|
_assert( |
|
H == self.img_size[0], |
|
f"Input image height ({H}) doesn't match model ({self.img_size[0]})." |
|
) |
|
_assert( |
|
W == self.img_size[1], |
|
f"Input image width ({W}) doesn't match model ({self.img_size[1]})." |
|
) |
|
x = self.proj(x) |
|
|
|
x = x.reshape(B, x.shape[1] // 3, 3, x.shape[-2], |
|
x.shape[-1]) |
|
|
|
if self.flatten: |
|
x = x.flatten(2).transpose(1, 2) |
|
x = self.norm(x) |
|
return x |
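
# Hedged usage sketch for PatchEmbedTriplane. The sizes are illustrative
# assumptions (three 4-channel planes stacked on the channel axis, ViT-B token
# width); the layer itself only requires in_chans divisible by its 3 groups.
def _patch_embed_triplane_example():
    embed = PatchEmbedTriplane(img_size=32, patch_size=2, in_chans=12,
                               embed_dim=768)
    x = torch.randn(2, 12, 32, 32)  # (B, 3 * C_plane, H, W)
    tokens = embed(x)  # grouped conv keeps planes separate
    assert tokens.shape == (2, 3 * 16 * 16, 768)
    return tokens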
|
|
|
|
|
class PatchEmbedTriplaneRodin(PatchEmbedTriplane): |
|
|
|
def __init__(self, |
|
img_size=32, |
|
patch_size=2, |
|
in_chans=4, |
|
embed_dim=768, |
|
norm_layer=None, |
|
flatten=True, |
|
bias=True): |
|
super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, |
|
flatten, bias) |
|
self.proj = RodinRollOutConv3D_GroupConv(in_chans, |
|
embed_dim * 3, |
|
kernel_size=patch_size, |
|
stride=patch_size, |
|
padding=0) |
|
|
|
|
|
class ViTTriplaneDecomposed(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder, |
|
triplane_decoder: Triplane, |
|
cls_token=False, |
|
decoder_pred_size=-1, |
|
unpatchify_out_chans=-1, |
|
|
|
channel_multiplier=4, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=4, |
|
fusion_blk=TriplaneFusionBlock, |
|
fusion_blk_start=0, |
|
ldm_z_channels=4, |
|
ldm_embed_dim=4, |
|
vae_p=2, |
|
token_size=None, |
|
w_avg=torch.zeros([512]), |
|
patch_size=None, |
|
**kwargs, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.superresolution = nn.ModuleDict({}) |
|
|
|
self.decomposed_IN = False |
|
|
|
self.decoder_pred_3d = None |
|
self.transformer_3D_blk = None |
|
self.logvar = None |
|
self.channel_multiplier = channel_multiplier |
|
|
|
self.cls_token = cls_token |
|
self.vit_decoder = vit_decoder |
|
self.triplane_decoder = triplane_decoder |
|
|
|
if patch_size is None: |
|
self.patch_size = self.vit_decoder.patch_embed.patch_size |
|
else: |
|
self.patch_size = patch_size |
|
|
|
if isinstance(self.patch_size, tuple): |
|
self.patch_size = self.patch_size[0] |
|
|
|
|
|
|
|
if unpatchify_out_chans == -1: |
|
self.unpatchify_out_chans = self.triplane_decoder.out_chans |
|
else: |
|
self.unpatchify_out_chans = unpatchify_out_chans |
|
|
|
|
|
if decoder_pred_size == -1: |
|
decoder_pred_size = self.patch_size**2 * self.triplane_decoder.out_chans |
|
|
|
self.decoder_pred = nn.Linear( |
|
self.vit_decoder.embed_dim, |
|
decoder_pred_size, |
|
|
|
|
|
bias=True) |
|
|
|
|
|
|
|
self.plane_n = 3 |
|
|
|
|
|
self.ldm_z_channels = ldm_z_channels |
|
self.ldm_embed_dim = ldm_embed_dim |
|
self.vae_p = vae_p |
|
        self.token_size = 16 if token_size is None else token_size  # 16x16 tokens by default
|
self.vae_res = self.vae_p * self.token_size |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.vit_decoder.pos_embed = nn.Parameter( |
|
torch.zeros(1, 3 * (self.token_size**2 + self.cls_token), |
|
vit_decoder.embed_dim)) |
|
|
|
self.fusion_blk_start = fusion_blk_start |
|
self.create_fusion_blks(fusion_blk_depth, use_fusion_blk, fusion_blk) |
|
|
|
|
|
|
|
|
|
self.register_buffer('w_avg', w_avg) |
|
self.rendering_kwargs = self.triplane_decoder.rendering_kwargs |
|
|
|
|
|
@torch.inference_mode() |
|
    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**16):
        """Query the triplane decoder at arbitrary 3D points, chunked over P.

        planes: (N, 3, C, H, W) or (N, 3*C, H, W); points: (N, P, 3).
        Returns a dict of (N, P, ...) tensors concatenated across chunks.
        """
        N, P = points.shape[:2]
|
if planes.ndim == 4: |
|
planes = planes.reshape( |
|
len(planes), |
|
3, |
|
-1, |
|
planes.shape[-2], |
|
planes.shape[-1]) |
|
|
|
|
|
outs = [] |
|
for i in trange(0, points.shape[1], chunk_size): |
|
chunk_points = points[:, i:i+chunk_size] |
|
|
|
|
|
|
|
chunk_out = self.triplane_decoder.renderer._run_model( |
|
planes=planes, |
|
decoder=self.triplane_decoder.decoder, |
|
sample_coordinates=chunk_points, |
|
sample_directions=torch.zeros_like(chunk_points), |
|
options=self.rendering_kwargs, |
|
) |
|
|
|
|
|
outs.append(chunk_out) |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
point_features = { |
|
k: torch.cat([out[k] for out in outs], dim=1) |
|
for k in outs[0].keys() |
|
} |
|
return point_features |
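
    # Hedged usage sketch (the plane size and the 'sigma' key are assumptions
    # about the default renderer, not guarantees of this method):
    #
    #   planes = torch.randn(1, 96, 256, 256)     # (N, 3*C, H, W)
    #   points = torch.rand(1, 2**18, 3) * 2 - 1  # (N, P, 3) query coordinates
    #   feats = model.forward_points(planes, points)
    #   feats['sigma']                             # -> (1, 2**18, 1)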
|
|
|
def triplane_decode_grid(self, vit_decode_out, grid_size, aabb: torch.Tensor = None, **kwargs): |
|
|
|
|
|
|
|
assert isinstance(vit_decode_out, dict) |
|
planes = vit_decode_out['latent_after_vit'] |
|
|
|
|
|
if aabb is None: |
|
if 'sampler_bbox_min' in self.rendering_kwargs: |
|
aabb = torch.tensor([ |
|
[self.rendering_kwargs['sampler_bbox_min']] * 3, |
|
[self.rendering_kwargs['sampler_bbox_max']] * 3, |
|
], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) |
|
else: |
|
aabb = torch.tensor([ |
|
[-self.rendering_kwargs['box_warp']/2] * 3, |
|
[self.rendering_kwargs['box_warp']/2] * 3, |
|
], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) |
|
|
|
assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" |
|
N = planes.shape[0] |
|
|
|
|
|
grid_points = [] |
|
for i in range(N): |
|
grid_points.append(torch.stack(torch.meshgrid( |
|
torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), |
|
torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), |
|
torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), |
|
indexing='ij', |
|
), dim=-1).reshape(-1, 3)) |
|
cube_grid = torch.stack(grid_points, dim=0).to(planes.device) |
|
|
|
|
|
features = self.forward_points(planes, cube_grid) |
|
|
|
|
|
features = { |
|
k: v.reshape(N, grid_size, grid_size, grid_size, -1) |
|
for k, v in features.items() |
|
} |
|
|
|
|
|
|
|
return features |
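
    # Hedged usage sketch: densely sample the decoded triplane on a voxel grid,
    # e.g. as marching-cubes input. Key names depend on the renderer:
    #
    #   grid = model.triplane_decode_grid(vit_decode_out, grid_size=128)
    #   grid['sigma']  # (N, 128, 128, 128, 1), if the decoder emits density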
|
|
|
|
|
def create_uvit_arch(self): |
|
|
|
logger.log( |
|
f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}') |
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
blk.skip_linear = nn.Linear(2 * self.vit_decoder.embed_dim, |
|
self.vit_decoder.embed_dim) |
|
|
|
|
|
nn.init.constant_(blk.skip_linear.weight, 0) |
|
if isinstance(blk.skip_linear, |
|
nn.Linear) and blk.skip_linear.bias is not None: |
|
nn.init.constant_(blk.skip_linear.bias, 0) |
|
|
|
|
|
|
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
return self.forward_vit_decoder(latent, img_size) |
|
|
|
def init_weights(self): |
|
|
|
p = self.token_size |
|
D = self.vit_decoder.pos_embed.shape[-1] |
|
grid_size = (3 * p, p) |
|
pos_embed = get_2d_sincos_pos_embed(D, |
|
grid_size).reshape(3 * p * p, |
|
D) |
|
self.vit_decoder.pos_embed.data.copy_( |
|
torch.from_numpy(pos_embed).float().unsqueeze(0)) |
|
logger.log('init pos_embed with sincos') |
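
    # The sincos table above covers a (3 * token_size, token_size) grid -- the
    # three plane token maps stacked along one axis -- flattened to
    # (3 * p * p, D) before being copied into pos_embed.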
|
|
|
|
|
def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): |
|
|
|
vit_decoder_blks = self.vit_decoder.blocks |
|
assert len(vit_decoder_blks) == 12, 'ViT-B by default' |
|
|
|
nh = self.vit_decoder.blocks[0].attn.num_heads |
|
dim = self.vit_decoder.embed_dim |
|
|
|
fusion_blk_start = self.fusion_blk_start |
|
triplane_fusion_vit_blks = nn.ModuleList() |
|
|
|
if fusion_blk_start != 0: |
|
for i in range(0, fusion_blk_start): |
|
triplane_fusion_vit_blks.append( |
|
vit_decoder_blks[i]) |
|
|
|
for i in range(fusion_blk_start, len(vit_decoder_blks), |
|
fusion_blk_depth): |
|
vit_blks_group = vit_decoder_blks[i:i + |
|
fusion_blk_depth] |
|
triplane_fusion_vit_blks.append( |
|
|
|
fusion_blk(vit_blks_group, nh, dim, use_fusion_blk)) |
|
|
|
self.vit_decoder.blocks = triplane_fusion_vit_blks |
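
    # Illustration: with the 12-block ViT-B decoder, fusion_blk_start=0 and
    # fusion_blk_depth=2, the loop wraps blocks as [0,1], [2,3], ..., [10,11],
    # yielding six fusion blocks that each mix the three plane token streams.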
|
|
|
def triplane_decode(self, latent, c): |
|
ret_dict = self.triplane_decoder(latent, c) |
|
ret_dict.update({'latent': latent}) |
|
return ret_dict |
|
|
|
def triplane_renderer(self, latent, coordinates, directions): |
|
|
|
planes = latent.view(len(latent), 3, |
|
self.triplane_decoder.decoder_in_chans, |
|
latent.shape[-2], |
|
latent.shape[-1]) |
|
|
|
ret_dict = self.triplane_decoder.renderer.run_model( |
|
planes, self.triplane_decoder.decoder, coordinates, directions, |
|
self.triplane_decoder.rendering_kwargs) |
|
|
|
return ret_dict |
|
|
|
|
|
|
|
|
|
    def unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None):
        """
        x: (N, 3*L, patch_size**2 * C), with C channels per plane
        triplanes: (N, 3*C, H, W)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans // 3
|
|
|
if self.cls_token: |
|
x = x[:, 1:] |
|
|
|
if p is None: |
|
p = self.patch_size |
|
h = w = int((x.shape[1] // 3)**.5) |
|
assert h * w * 3 == x.shape[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], 3, h, w, p, p, unpatchify_out_chans)) |
|
x = torch.einsum('ndhwpqc->ndchpwq', |
|
x) |
|
        triplanes = x.reshape(shape=(x.shape[0], unpatchify_out_chans * 3,
                                     h * p, w * p))
|
return triplanes |
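
    # Shape walk-through (assuming p=2 and C=12 channels per plane):
    #   (N, 3*h*w, 4*12) -> (N, 3, h, w, 2, 2, 12) -> permute -> (N, 3, 12, h, 2, w, 2)
    #   -> (N, 36, 2*h, 2*w)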
|
|
|
    def interpolate_pos_encoding(self, x, w, h):
        # The triplane token grid is fixed at 3 * token_size**2 tokens, so the
        # positional table never needs resolution-dependent resampling; return
        # it directly instead of the bicubic interpolation used by DINO.
        return self.vit_decoder.pos_embed
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
|
|
|
|
|
|
if img_size is None: |
|
img_size = self.img_size |
|
|
|
if self.cls_token: |
|
x = x + self.vit_decoder.interpolate_pos_encoding( |
|
x, img_size, img_size)[:, :] |
|
else: |
|
x = x + self.vit_decoder.interpolate_pos_encoding( |
|
x, img_size, img_size)[:, 1:] |
|
|
|
for blk in self.vit_decoder.blocks: |
|
x = blk(x) |
|
x = self.vit_decoder.norm(x) |
|
|
|
return x |
|
|
|
def unpatchify(self, x, p=None, unpatchify_out_chans=None): |
|
""" |
|
x: (N, L, patch_size**2 * self.out_chans) |
|
imgs: (N, self.out_chans, H, W) |
|
""" |
|
|
|
if unpatchify_out_chans is None: |
|
unpatchify_out_chans = self.unpatchify_out_chans |
|
|
|
if self.cls_token: |
|
x = x[:, 1:] |
|
|
|
if p is None: |
|
p = self.patch_size |
|
h = w = int(x.shape[1]**.5) |
|
assert h * w == x.shape[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, unpatchify_out_chans)) |
|
x = torch.einsum('nhwpqc->nchpwq', x) |
|
        imgs = x.reshape(shape=(x.shape[0], unpatchify_out_chans, h * p,
                                w * p))
|
return imgs |
|
|
|
def forward(self, latent, c, img_size): |
|
latent = self.forward_vit_decoder(latent, img_size) |
|
|
|
if self.cls_token: |
|
|
|
cls_token = latent[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
|
|
latent = self.decoder_pred( |
|
latent) |
|
|
|
latent = self.unpatchify( |
|
latent) |
|
|
|
|
|
|
|
|
|
|
|
ret_dict = self.triplane_decoder(planes=latent, c=c) |
|
ret_dict.update({'latent': latent, 'cls_token': cls_token}) |
|
|
|
return ret_dict |
|
|
|
|
|
class VAE_LDM_V4_vit3D_v3_conv3D_depth2_xformer_mha_PEinit_2d_sincos_uvit_RodinRollOutConv_4x4_lite_mlp_unshuffle_4XC_final(
        ViTTriplaneDecomposed):
    """4x4 SR decoder with 2X channels.

    1. Reuses the attention projection layer from DINO.
    2. Reuses attention: first self-attention, then 3D cross-attention.
    """
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane, |
|
cls_token, |
|
|
|
|
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
channel_multiplier=4, |
|
fusion_blk=TriplaneFusionBlockv3, |
|
**kwargs) -> None: |
|
super().__init__( |
|
vit_decoder, |
|
triplane_decoder, |
|
cls_token, |
|
|
|
|
|
fusion_blk=fusion_blk, |
|
use_fusion_blk=use_fusion_blk, |
|
fusion_blk_depth=fusion_blk_depth, |
|
channel_multiplier=channel_multiplier, |
|
decoder_pred_size=(4 // 1)**2 * |
|
int(triplane_decoder.out_chans // 3 * channel_multiplier), |
|
**kwargs) |
|
|
|
patch_size = vit_decoder.patch_embed.patch_size |
|
|
|
self.reparameterization_soft_clamp = False |
|
|
|
if isinstance(patch_size, tuple): |
|
patch_size = patch_size[0] |
|
|
|
|
|
        unpatchify_out_chans = triplane_decoder.out_chans * 1

        if unpatchify_out_chans == -1:
            unpatchify_out_chans = triplane_decoder.out_chans * 3
|
|
|
        ldm_z_channels = ldm_embed_dim = triplane_decoder.out_chans
|
|
|
self.superresolution.update( |
|
dict( |
|
after_vit_conv=nn.Conv2d( |
|
int(triplane_decoder.out_chans * 2), |
|
triplane_decoder.out_chans * 2, |
|
3, |
|
padding=1), |
|
quant_conv=torch.nn.Conv2d(2 * ldm_z_channels, |
|
2 * ldm_embed_dim, 1), |
|
ldm_downsample=nn.Linear( |
|
384, |
|
|
|
self.vae_p * self.vae_p * 3 * self.ldm_z_channels * |
|
2, |
|
bias=True), |
|
ldm_upsample=nn.Linear(self.vae_p * self.vae_p * |
|
self.ldm_z_channels * 1, |
|
vit_decoder.embed_dim, |
|
bias=True), |
|
quant_mlp=Mlp(2 * self.ldm_z_channels, |
|
out_features=2 * self.ldm_embed_dim), |
|
conv_sr=RodinConv3D4X_lite_mlp_as_residual( |
|
int(triplane_decoder.out_chans * channel_multiplier), |
|
int(triplane_decoder.out_chans * 1)))) |
|
|
|
has_token = bool(self.cls_token) |
|
self.vit_decoder.pos_embed = nn.Parameter( |
|
torch.zeros(1, 3 * 16 * 16 + has_token, vit_decoder.embed_dim)) |
|
|
|
self.init_weights() |
|
self.reparameterization_soft_clamp = True |
|
|
|
self.create_uvit_arch() |
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
"""input: latent from ViT encoder |
|
""" |
|
|
|
latents3D = self.superresolution['ldm_downsample'](latent) |
|
|
|
if self.vae_p > 1: |
|
latents3D = self.unpatchify3D( |
|
latents3D, |
|
p=self.vae_p, |
|
unpatchify_out_chans=self.ldm_z_channels * |
|
2) |
|
latents3D = latents3D.reshape( |
|
latents3D.shape[0], 3, -1, latents3D.shape[-1] |
|
) |
|
else: |
|
latents3D = latents3D.reshape(latents3D.shape[0], |
|
latents3D.shape[1], 3, |
|
2 * self.ldm_z_channels) |
|
latents3D = latents3D.permute(0, 2, 1, 3) |
|
|
|
|
|
|
|
|
|
|
|
posterior = self.vae_encode(latents3D) |
|
|
|
if sample_posterior: |
|
latent = posterior.sample() |
|
else: |
|
latent = posterior.mode() |
|
|
|
log_q = posterior.log_p(latent) |
|
|
|
|
|
|
|
|
|
|
|
latent_normalized_2Ddiffusion = latent.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
log_q_2Ddiffusion = log_q.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
latent = latent.permute(0, 2, 3, 1) |
|
|
|
latent = latent.reshape(latent.shape[0], -1, |
|
latent.shape[-1]) |
|
|
|
ret_dict = dict( |
|
normal_entropy=posterior.normal_entropy(), |
|
latent_normalized=latent, |
|
latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, |
|
log_q_2Ddiffusion=log_q_2Ddiffusion, |
|
log_q=log_q, |
|
posterior=posterior, |
|
            latent_name='latent_normalized',
        )
|
|
|
return ret_dict |
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
|
|
latent = self.decoder_pred( |
|
latent_from_vit |
|
) |
|
|
|
latent = self.unpatchify_triplane( |
|
latent, |
|
p=4, |
|
unpatchify_out_chans=int( |
|
self.channel_multiplier * self.unpatchify_out_chans // |
|
3)) |
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
sr_w_code = self.w_avg |
|
assert sr_w_code is not None |
|
ret_dict.update( |
|
dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( |
|
latent_from_vit.shape[0], 0), )) |
|
|
|
return ret_dict |
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
|
|
|
|
|
|
if img_size is None: |
|
img_size = self.img_size |
|
|
|
|
|
|
|
x = x + self.interpolate_pos_encoding(x, img_size, |
|
img_size)[:, :] |
|
|
|
B, L, C = x.shape |
|
x = x.view(B, 3, L // 3, C) |
|
|
|
skips = [x] |
|
assert self.fusion_blk_start == 0 |
|
|
|
|
|
for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // |
|
2 - 1]: |
|
x = blk(x) |
|
skips.append(x) |
|
|
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - |
|
1:len(self.vit_decoder.blocks) // |
|
2]: |
|
x = blk(x) |
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
x = x + blk.skip_linear(torch.cat([x, skips.pop()], |
|
dim=-1)) |
|
x = blk(x) |
|
|
|
x = self.vit_decoder.norm(x) |
|
|
|
|
|
x = x.view(B, L, C) |
|
return x |
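
    # U-ViT wiring above: the first half of the blocks pushes activations onto
    # `skips`; each block in the second half pops one and mixes it back in
    # through its zero-initialized skip_linear, so training starts from a
    # plain skip-free transformer (see create_uvit_arch).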
|
|
|
def triplane_decode(self, |
|
vit_decode_out, |
|
c, |
|
return_raw_only=False, |
|
**kwargs): |
|
|
|
if isinstance(vit_decode_out, dict): |
|
latent_after_vit, sr_w_code = (vit_decode_out.get(k, None) |
|
for k in ('latent_after_vit', |
|
'sr_w_code')) |
|
|
|
else: |
|
latent_after_vit = vit_decode_out |
|
sr_w_code = None |
|
            vit_decode_out = dict(latent_normalized=latent_after_vit)
|
|
|
|
|
ret_dict = self.triplane_decoder(latent_after_vit, |
|
c, |
|
ws=sr_w_code, |
|
return_raw_only=return_raw_only, |
|
**kwargs) |
|
ret_dict.update({ |
|
'latent_after_vit': latent_after_vit, |
|
**vit_decode_out |
|
}) |
|
|
|
return ret_dict |
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
|
|
        if isinstance(latent, dict):
            if 'latent_normalized' in latent:
                latent = latent['latent_normalized']
            else:
                latent = latent['latent_normalized_2Ddiffusion']
|
|
|
|
|
if latent.ndim != 3: |
|
latent = latent.reshape(latent.shape[0], latent.shape[1] // 3, 3, |
|
(self.vae_p * self.token_size)**2).permute( |
|
0, 2, 3, 1) |
|
latent = latent.reshape(latent.shape[0], -1, |
|
latent.shape[-1]) |
|
|
|
        assert latent.shape == (
            latent.shape[0],
            3 * ((self.vae_p * self.token_size)**2),
            self.ldm_z_channels), f'latent.shape: {latent.shape}'
|
|
|
latent = self.superresolution['ldm_upsample'](latent) |
|
|
|
return super().vit_decode_backbone( |
|
latent, img_size) |
|
|
|
|
|
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn( |
|
ViTTriplaneDecomposed): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
|
|
|
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino, |
|
channel_multiplier=4, |
|
ldm_z_channels=4, |
|
ldm_embed_dim=4, |
|
vae_p=2, |
|
**kwargs) -> None: |
|
|
|
super().__init__( |
|
vit_decoder, |
|
triplane_decoder, |
|
cls_token, |
|
|
|
channel_multiplier=channel_multiplier, |
|
use_fusion_blk=use_fusion_blk, |
|
fusion_blk_depth=fusion_blk_depth, |
|
fusion_blk=fusion_blk, |
|
ldm_z_channels=ldm_z_channels, |
|
ldm_embed_dim=ldm_embed_dim, |
|
vae_p=vae_p, |
|
decoder_pred_size=(4 // 1)**2 * |
|
int(triplane_decoder.out_chans // 3 * channel_multiplier), |
|
**kwargs) |
|
|
|
logger.log( |
|
f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}') |
|
|
|
|
|
self.superresolution.update( |
|
dict( |
|
ldm_downsample=nn.Linear( |
|
384, |
|
self.vae_p * self.vae_p * 3 * self.ldm_z_channels * |
|
2, |
|
bias=True), |
|
ldm_upsample=PatchEmbedTriplane( |
|
self.vae_p * self.token_size, |
|
self.vae_p, |
|
3 * self.ldm_embed_dim, |
|
vit_decoder.embed_dim, |
|
bias=True), |
|
quant_conv=nn.Conv2d(2 * 3 * self.ldm_z_channels, |
|
2 * self.ldm_embed_dim * 3, |
|
kernel_size=1, |
|
groups=3), |
|
conv_sr=RodinConv3D4X_lite_mlp_as_residual_lite( |
|
int(triplane_decoder.out_chans * channel_multiplier), |
|
int(triplane_decoder.out_chans * 1)))) |
|
|
|
|
|
self.init_weights() |
|
self.reparameterization_soft_clamp = True |
|
|
|
self.create_uvit_arch() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vit_decode(self, latent, img_size, sample_posterior=True): |
|
|
|
ret_dict = self.vae_reparameterization(latent, sample_posterior) |
|
|
|
|
|
latent = self.vit_decode_backbone(ret_dict, img_size) |
|
return self.vit_decode_postprocess(latent, ret_dict) |
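
    # Hedged end-to-end sketch of the three-stage decode. The 384-wide input is
    # what ldm_downsample expects (a DINO ViT-S/B feature width); the token
    # count follows (vae_res // vae_p)**2 = 16*16 with the defaults:
    #
    #   latent = torch.randn(B, 256, 384)       # illustrative encoder tokens
    #   out = model.vit_decode(latent, img_size=224)
    #   out['latent_after_vit']                 # SR triplane for triplane_decode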
|
|
|
|
|
    def unpatchify3D(self, x, p, unpatchify_out_chans, plane_n=3):
        """
        x: (N, L, p**2 * plane_n * unpatchify_out_chans)
        latents3D: (N, plane_n, H, W, unpatchify_out_chans)
        """
|
|
|
if self.cls_token: |
|
x = x[:, 1:] |
|
|
|
h = w = int(x.shape[1]**.5) |
|
assert h * w == x.shape[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, plane_n, |
|
unpatchify_out_chans)) |
|
|
|
x = torch.einsum( |
|
'nhwpqdc->ndhpwqc', x |
|
) |
|
|
|
latents3D = x.reshape(shape=(x.shape[0], plane_n, h * p, h * p, |
|
unpatchify_out_chans)) |
|
return latents3D |
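
    # Shape walk-through (p = vae_p, C = 2 * ldm_z_channels for mean+logvar):
    #   (N, h*w, p*p*3*C) -> (N, h, w, p, p, 3, C) -> permute -> (N, 3, h*p, w*p, C)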
|
|
|
|
|
def vae_encode(self, h): |
|
|
|
|
|
|
|
B, _, H, W = h.shape |
|
moments = self.superresolution['quant_conv'](h) |
|
|
|
moments = moments.reshape( |
|
B, |
|
|
|
moments.shape[1] // self.plane_n, |
|
|
|
self.plane_n, |
|
H, |
|
W, |
|
) |
|
|
|
moments = moments.flatten(-2) |
|
|
|
posterior = DiagonalGaussianDistribution( |
|
moments, soft_clamp=self.reparameterization_soft_clamp) |
|
return posterior |
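
    # Contract sketch: quant_conv is a grouped 1x1 conv, so each plane's
    # (mean, logvar) moments are mapped independently; after the reshape and
    # flatten, DiagonalGaussianDistribution splits the moment axis into mean
    # and logvar, and callers draw posterior.sample() during training or
    # posterior.mode() for deterministic evaluation.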
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
"""input: latent from ViT encoder |
|
""" |
|
|
|
|
|
latents3D = self.superresolution['ldm_downsample']( |
|
latent) |
|
|
|
assert self.vae_p > 1 |
|
latents3D = self.unpatchify3D( |
|
latents3D, |
|
p=self.vae_p, |
|
unpatchify_out_chans=self.ldm_z_channels * |
|
2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
B, _, H, W, C = latents3D.shape |
|
latents3D = latents3D.permute(0, 1, 4, 2, 3).reshape(B, -1, H, |
|
W) |
|
|
|
|
|
posterior = self.vae_encode(latents3D) |
|
|
|
if sample_posterior: |
|
latent = posterior.sample() |
|
else: |
|
latent = posterior.mode() |
|
|
|
log_q = posterior.log_p(latent) |
|
|
|
|
|
latent_normalized_2Ddiffusion = latent.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
log_q_2Ddiffusion = log_q.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
|
|
latent = latent.permute(0, 2, 3, 1) |
|
|
|
latent = latent.reshape(latent.shape[0], -1, |
|
latent.shape[-1]) |
|
|
|
ret_dict = dict( |
|
normal_entropy=posterior.normal_entropy(), |
|
latent_normalized=latent, |
|
latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, |
|
log_q_2Ddiffusion=log_q_2Ddiffusion, |
|
log_q=log_q, |
|
posterior=posterior, |
|
) |
|
|
|
return ret_dict |
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
|
|
if isinstance(latent, dict): |
|
latent = latent['latent_normalized_2Ddiffusion'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
latent = self.superresolution['ldm_upsample']( |
|
latent) |
|
|
|
|
|
|
|
return self.forward_vit_decoder(latent, img_size) |
|
|
|
def triplane_decode(self, |
|
vit_decode_out, |
|
c, |
|
return_raw_only=False, |
|
**kwargs): |
|
|
|
if isinstance(vit_decode_out, dict): |
|
latent_after_vit, sr_w_code = (vit_decode_out.get(k, None) |
|
for k in ('latent_after_vit', |
|
'sr_w_code')) |
|
|
|
else: |
|
latent_after_vit = vit_decode_out |
|
sr_w_code = None |
|
            vit_decode_out = dict(latent_normalized=latent_after_vit)
|
|
|
|
|
ret_dict = self.triplane_decoder(latent_after_vit, |
|
c, |
|
ws=sr_w_code, |
|
return_raw_only=return_raw_only, |
|
**kwargs) |
|
ret_dict.update({ |
|
'latent_after_vit': latent_after_vit, |
|
**vit_decode_out |
|
}) |
|
|
|
return ret_dict |
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
|
|
latent = self.decoder_pred( |
|
latent_from_vit |
|
) |
|
|
|
latent = self.unpatchify_triplane( |
|
latent, |
|
p=4, |
|
unpatchify_out_chans=int( |
|
self.channel_multiplier * self.unpatchify_out_chans // |
|
3)) |
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
sr_w_code = self.w_avg |
|
assert sr_w_code is not None |
|
ret_dict.update( |
|
dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( |
|
latent_from_vit.shape[0], 0), )) |
|
|
|
return ret_dict |
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
|
|
|
|
|
|
if img_size is None: |
|
img_size = self.img_size |
|
|
|
|
|
|
|
x = x + self.interpolate_pos_encoding(x, img_size, |
|
img_size)[:, :] |
|
|
|
B, L, C = x.shape |
|
x = x.view(B, 3, L // 3, C) |
|
|
|
skips = [x] |
|
assert self.fusion_blk_start == 0 |
|
|
|
|
|
for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // |
|
2 - 1]: |
|
x = blk(x) |
|
skips.append(x) |
|
|
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - |
|
1:len(self.vit_decoder.blocks) // |
|
2]: |
|
x = blk(x) |
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
x = x + blk.skip_linear(torch.cat([x, skips.pop()], |
|
dim=-1)) |
|
x = blk(x) |
|
|
|
x = self.vit_decoder.norm(x) |
|
|
|
|
|
x = x.view(B, L, C) |
|
return x |
|
|
|
|
|
|
|
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD( |
|
RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn): |
|
|
|
def __init__(self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, |
|
triplane_decoder, |
|
cls_token, |
|
|
|
use_fusion_blk=use_fusion_blk, |
|
fusion_blk_depth=fusion_blk_depth, |
|
fusion_blk=fusion_blk, |
|
channel_multiplier=channel_multiplier, |
|
**kwargs) |
|
|
|
for k in [ |
|
'ldm_downsample', |
|
|
|
]: |
|
del self.superresolution[k] |
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
|
|
|
|
assert self.vae_p > 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
posterior = self.vae_encode(latent) |
|
|
|
if sample_posterior: |
|
latent = posterior.sample() |
|
else: |
|
latent = posterior.mode() |
|
|
|
log_q = posterior.log_p(latent) |
|
|
|
|
|
latent_normalized_2Ddiffusion = latent.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
log_q_2Ddiffusion = log_q.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
|
|
latent = latent.permute(0, 2, 3, 1) |
|
|
|
latent = latent.reshape(latent.shape[0], -1, |
|
latent.shape[-1]) |
|
|
|
ret_dict = dict( |
|
normal_entropy=posterior.normal_entropy(), |
|
latent_normalized=latent, |
|
latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, |
|
log_q_2Ddiffusion=log_q_2Ddiffusion, |
|
log_q=log_q, |
|
posterior=posterior, |
|
) |
|
|
|
return ret_dict |
|
|
|
|
|
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD_D( |
|
RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD): |
|
|
|
def __init__(self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, triplane_decoder, cls_token, |
|
normalize_feat, sr_ratio, use_fusion_blk, |
|
fusion_blk_depth, fusion_blk, channel_multiplier, |
|
**kwargs) |
|
|
|
self.decoder_pred = None |
|
|
|
self.superresolution.update( |
|
dict(conv_sr=Decoder( |
|
resolution=128, |
|
in_channels=3, |
|
|
|
ch=32, |
|
ch_mult=[1, 2, 2, 4], |
|
|
|
|
|
num_res_blocks=1, |
|
dropout=0.0, |
|
attn_resolutions=[], |
|
out_ch=32, |
|
|
|
z_channels=vit_decoder.embed_dim, |
|
))) |
|
|
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
def unflatten_token(x, p=None): |
|
B, L, C = x.shape |
|
x = x.reshape(B, 3, L // 3, C) |
|
|
|
if self.cls_token: |
|
x = x[:, :, 1:] |
|
|
|
h = w = int((x.shape[2])**.5) |
|
assert h * w == x.shape[2] |
|
|
|
if p is None: |
|
x = x.reshape(shape=(B, 3, h, w, -1)) |
|
x = rearrange( |
|
x, 'b n h w c->(b n) c h w' |
|
) |
|
else: |
|
x = x.reshape(shape=(B, 3, h, w, p, p, -1)) |
|
x = rearrange( |
|
x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)' |
|
) |
|
|
|
return x |
|
|
|
latent = unflatten_token(latent_from_vit) |
|
|
|
|
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ret_dict |
|
|
|
|
|
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_lite3DAttn( |
|
RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD): |
|
|
|
def __init__(self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, triplane_decoder, cls_token, |
|
normalize_feat, sr_ratio, use_fusion_blk, |
|
fusion_blk_depth, fusion_blk, channel_multiplier, |
|
**kwargs) |
|
|
|
|
|
|
|
|
|
        # 2048 matches the per-plane decoder_pred_size used elsewhere, i.e.
        # 4**2 * (out_chans // 3 * channel_multiplier) for the default
        # 96-channel triplane (an assumption, not checked here).
        self.decoder_pred = nn.Linear(self.vit_decoder.embed_dim // 3,
                                      2048,
                                      bias=True)
|
|
|
|
|
self.superresolution.update( |
|
dict(ldm_upsample=PatchEmbedTriplaneRodin( |
|
self.vae_p * self.token_size, |
|
self.vae_p, |
|
3 * self.ldm_embed_dim, |
|
vit_decoder.embed_dim // 3, |
|
bias=True))) |
|
|
|
|
|
has_token = bool(self.cls_token) |
|
self.vit_decoder.pos_embed = nn.Parameter( |
|
torch.zeros(1, 16 * 16 + has_token, vit_decoder.embed_dim)) |
|
|
|
def forward(self, latent, c, img_size): |
|
|
|
latent_normalized = self.vit_decode(latent, img_size) |
|
return self.triplane_decode(latent_normalized, c) |
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
|
|
|
|
assert self.vae_p > 1 |
|
|
|
|
|
|
|
posterior = self.vae_encode(latent) |
|
|
|
if sample_posterior: |
|
latent = posterior.sample() |
|
else: |
|
latent = posterior.mode() |
|
|
|
log_q = posterior.log_p(latent) |
|
|
|
|
|
latent_normalized_2Ddiffusion = latent.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
log_q_2Ddiffusion = log_q.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
|
|
|
|
|
|
|
|
latent = latent.permute(0, 3, 1, 2) |
|
latent = latent.reshape(*latent.shape[:2], -1) |
|
|
|
ret_dict = dict( |
|
normal_entropy=posterior.normal_entropy(), |
|
latent_normalized=latent, |
|
latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, |
|
log_q_2Ddiffusion=log_q_2Ddiffusion, |
|
log_q=log_q, |
|
posterior=posterior, |
|
) |
|
|
|
return ret_dict |
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
B, N, C = latent_from_vit.shape |
|
latent_from_vit = latent_from_vit.reshape(B, N, C // 3, 3).permute( |
|
0, 3, 1, 2) |
|
|
|
|
|
|
|
|
|
latent = self.decoder_pred( |
|
latent_from_vit |
|
) |
|
|
|
latent = latent.reshape(B, 3 * N, -1) |
|
|
|
latent = self.unpatchify_triplane( |
|
latent, |
|
p=4, |
|
unpatchify_out_chans=int( |
|
self.channel_multiplier * self.unpatchify_out_chans // |
|
3)) |
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
sr_w_code = self.w_avg |
|
assert sr_w_code is not None |
|
ret_dict.update( |
|
dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( |
|
latent_from_vit.shape[0], 0), )) |
|
|
|
return ret_dict |
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
|
|
if isinstance(latent, dict): |
|
latent = latent['latent_normalized_2Ddiffusion'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
latent = self.superresolution['ldm_upsample']( |
|
latent) |
|
|
|
|
|
B, N3, C = latent.shape |
|
latent = latent.reshape(B, 3, N3 // 3, |
|
C).permute(0, 2, 3, 1) |
|
latent = latent.reshape(*latent.shape[:2], -1) |
|
|
|
|
|
return self.forward_vit_decoder(latent, img_size) |
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
|
|
|
|
|
|
if img_size is None: |
|
img_size = self.img_size |
|
|
|
|
|
x = x + self.interpolate_pos_encoding(x, img_size, |
|
img_size)[:, :] |
|
|
|
B, L, C = x.shape |
|
|
|
|
|
|
|
|
|
skips = [x] |
|
assert self.fusion_blk_start == 0 |
|
|
|
|
|
for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // |
|
2 - 1]: |
|
x = blk(x) |
|
skips.append(x) |
|
|
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - |
|
1:len(self.vit_decoder.blocks) // |
|
2]: |
|
x = blk(x) |
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
x = x + blk.skip_linear(torch.cat([x, skips.pop()], |
|
dim=-1)) |
|
x = blk(x) |
|
|
|
x = self.vit_decoder.norm(x) |
|
|
|
|
|
x = x.view(B, L, C) |
|
return x |
|
|
|
def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): |
|
|
|
vit_decoder_blks = self.vit_decoder.blocks |
|
assert len(vit_decoder_blks) == 12, 'ViT-B by default' |
|
|
|
        nh = self.vit_decoder.blocks[0].attn.num_heads // 3
        dim = self.vit_decoder.embed_dim // 3
|
|
|
fusion_blk_start = self.fusion_blk_start |
|
triplane_fusion_vit_blks = nn.ModuleList() |
|
|
|
if fusion_blk_start != 0: |
|
for i in range(0, fusion_blk_start): |
|
triplane_fusion_vit_blks.append( |
|
vit_decoder_blks[i]) |
|
|
|
for i in range(fusion_blk_start, len(vit_decoder_blks), |
|
fusion_blk_depth): |
|
vit_blks_group = vit_decoder_blks[i:i + |
|
fusion_blk_depth] |
|
triplane_fusion_vit_blks.append( |
|
|
|
fusion_blk(vit_blks_group, nh, dim, use_fusion_blk)) |
|
|
|
self.vit_decoder.blocks = triplane_fusion_vit_blks |
|
|
|
|
|
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S( |
|
RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__( |
|
vit_decoder, |
|
triplane_decoder, |
|
cls_token, |
|
use_fusion_blk=use_fusion_blk, |
|
fusion_blk_depth=fusion_blk_depth, |
|
fusion_blk=fusion_blk, |
|
channel_multiplier=channel_multiplier, |
|
patch_size=-1, |
|
token_size=2, |
|
**kwargs) |
|
self.D_roll_out_input = False |
|
|
|
for k in [ |
|
'ldm_downsample', |
|
|
|
]: |
|
del self.superresolution[k] |
|
|
|
self.decoder_pred = None |
|
self.superresolution.update( |
|
dict( |
|
conv_sr=Decoder( |
|
resolution=128, |
|
|
|
in_channels=3, |
|
|
|
ch=32, |
|
|
|
ch_mult=[1, 2, 2, 4], |
|
|
|
|
|
|
|
|
|
num_res_blocks=1, |
|
dropout=0.0, |
|
attn_resolutions=[], |
|
out_ch=32, |
|
|
|
z_channels=vit_decoder.embed_dim, |
|
|
|
), |
|
|
|
)) |
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
del blk.skip_linear |
|
|
|
@torch.inference_mode() |
|
def forward_points(self, |
|
planes, |
|
points: torch.Tensor, |
|
chunk_size: int = 2**16): |
|
|
|
|
|
N, P = points.shape[:2] |
|
if planes.ndim == 4: |
|
planes = planes.reshape( |
|
len(planes), |
|
3, |
|
-1, |
|
planes.shape[-2], |
|
planes.shape[-1]) |
|
|
|
|
|
outs = [] |
|
for i in trange(0, points.shape[1], chunk_size): |
|
chunk_points = points[:, i:i + chunk_size] |
|
|
|
|
|
|
|
chunk_out = self.triplane_decoder.renderer._run_model( |
|
planes=planes, |
|
decoder=self.triplane_decoder.decoder, |
|
sample_coordinates=chunk_points, |
|
sample_directions=torch.zeros_like(chunk_points), |
|
options=self.rendering_kwargs, |
|
) |
|
|
|
|
|
outs.append(chunk_out) |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
point_features = { |
|
k: torch.cat([out[k] for out in outs], dim=1) |
|
for k in outs[0].keys() |
|
} |
|
return point_features |
|
|
|
def triplane_decode_grid(self, |
|
vit_decode_out, |
|
grid_size, |
|
aabb: torch.Tensor = None, |
|
**kwargs): |
|
|
|
|
|
|
|
assert isinstance(vit_decode_out, dict) |
|
planes = vit_decode_out['latent_after_vit'] |
|
|
|
|
|
if aabb is None: |
|
if 'sampler_bbox_min' in self.rendering_kwargs: |
|
aabb = torch.tensor([ |
|
[self.rendering_kwargs['sampler_bbox_min']] * 3, |
|
[self.rendering_kwargs['sampler_bbox_max']] * 3, |
|
], |
|
device=planes.device, |
|
dtype=planes.dtype).unsqueeze(0).repeat( |
|
planes.shape[0], 1, 1) |
|
else: |
|
aabb = torch.tensor( |
|
[ |
|
[-self.rendering_kwargs['box_warp'] / 2] * 3, |
|
[self.rendering_kwargs['box_warp'] / 2] * 3, |
|
], |
|
device=planes.device, |
|
dtype=planes.dtype).unsqueeze(0).repeat( |
|
planes.shape[0], 1, 1) |
|
|
|
assert planes.shape[0] == aabb.shape[ |
|
0], "Batch size mismatch for planes and aabb" |
|
N = planes.shape[0] |
|
|
|
|
|
grid_points = [] |
|
for i in range(N): |
|
grid_points.append( |
|
torch.stack(torch.meshgrid( |
|
torch.linspace(aabb[i, 0, 0], |
|
aabb[i, 1, 0], |
|
grid_size, |
|
device=planes.device), |
|
torch.linspace(aabb[i, 0, 1], |
|
aabb[i, 1, 1], |
|
grid_size, |
|
device=planes.device), |
|
torch.linspace(aabb[i, 0, 2], |
|
aabb[i, 1, 2], |
|
grid_size, |
|
device=planes.device), |
|
indexing='ij', |
|
), |
|
dim=-1).reshape(-1, 3)) |
|
cube_grid = torch.stack(grid_points, dim=0).to(planes.device) |
|
|
|
|
|
features = self.forward_points(planes, cube_grid) |
|
|
|
|
|
features = { |
|
k: v.reshape(N, grid_size, grid_size, grid_size, -1) |
|
for k, v in features.items() |
|
} |
|
|
|
|
|
|
|
return features |
|
|
|
def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): |
|
|
|
pass |
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
return self.vit_decoder(x) |
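
    # The DiT-style decoder consumes the (B, L, C) token sequence directly and
    # applies its own positional embedding, so the U-ViT skip wiring and the
    # pos-embed handling of the other variants are bypassed here (skip_linear
    # is deleted in __init__ and create_fusion_blks is a no-op).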
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
|
|
if isinstance(latent, dict): |
|
latent = latent['latent_normalized_2Ddiffusion'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latent = self.superresolution['ldm_upsample']( |
|
latent) |
|
|
|
|
|
|
|
return self.forward_vit_decoder(latent, img_size) |
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
def unflatten_token(x, p=None): |
|
B, L, C = x.shape |
|
x = x.reshape(B, 3, L // 3, C) |
|
|
|
if self.cls_token: |
|
x = x[:, :, 1:] |
|
|
|
h = w = int((x.shape[2])**.5) |
|
assert h * w == x.shape[2] |
|
|
|
if p is None: |
|
x = x.reshape(shape=(B, 3, h, w, -1)) |
|
if not self.D_roll_out_input: |
|
x = rearrange( |
|
x, 'b n h w c->(b n) c h w' |
|
) |
|
else: |
|
x = rearrange( |
|
x, 'b n h w c->b c h (n w)' |
|
) |
|
else: |
|
x = x.reshape(shape=(B, 3, h, w, p, p, -1)) |
|
if self.D_roll_out_input: |
|
x = rearrange( |
|
x, 'b n h w p1 p2 c->b c (h p1) (n w p2)' |
|
) |
|
else: |
|
x = rearrange( |
|
x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)' |
|
) |
|
|
|
return x |
|
|
|
latent = unflatten_token( |
|
latent_from_vit) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
if not self.D_roll_out_input: |
|
latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3) |
|
else: |
|
latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ret_dict |
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
|
|
|
|
assert self.vae_p > 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
posterior = self.vae_encode(latent) |
|
|
|
if sample_posterior: |
|
latent = posterior.sample() |
|
else: |
|
latent = posterior.mode() |
|
|
|
log_q = posterior.log_p(latent) |
|
|
|
|
|
latent_normalized_2Ddiffusion = latent.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
log_q_2Ddiffusion = log_q.reshape( |
|
latent.shape[0], -1, self.token_size * self.vae_p, |
|
self.token_size * self.vae_p) |
|
|
|
|
|
latent = latent.permute(0, 2, 3, 1) |
|
|
|
latent = latent.reshape(latent.shape[0], -1, |
|
latent.shape[-1]) |
|
|
|
ret_dict = dict( |
|
normal_entropy=posterior.normal_entropy(), |
|
latent_normalized=latent, |
|
latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, |
|
log_q_2Ddiffusion=log_q_2Ddiffusion, |
|
log_q=log_q, |
|
posterior=posterior, |
|
) |
|
|
|
return ret_dict |
|
|
|
|
|
|
|
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout( |
|
RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, triplane_decoder, cls_token, |
|
normalize_feat, sr_ratio, use_fusion_blk, |
|
fusion_blk_depth, fusion_blk, channel_multiplier, |
|
**kwargs) |
|
|
|
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D( |
|
RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout |
|
): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, triplane_decoder, cls_token, |
|
normalize_feat, sr_ratio, use_fusion_blk, |
|
fusion_blk_depth, fusion_blk, channel_multiplier, |
|
**kwargs) |
|
|
|
self.decoder_pred = None |
|
self.superresolution.update( |
|
dict( |
|
conv_sr=Decoder( |
|
resolution=128, |
|
|
|
in_channels=3, |
|
|
|
ch=32, |
|
|
|
ch_mult=[1, 2, 2, 4], |
|
|
|
|
|
|
|
|
|
num_res_blocks=1, |
|
dropout=0.0, |
|
attn_resolutions=[], |
|
out_ch=32, |
|
|
|
z_channels=vit_decoder.embed_dim, |
|
|
|
), |
|
|
|
)) |
|
self.D_roll_out_input = False |
|
|
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
|
|
if self.cls_token: |
|
cls_token = latent_from_vit[:, :1] |
|
else: |
|
cls_token = None |
|
|
|
def unflatten_token(x, p=None): |
|
B, L, C = x.shape |
|
x = x.reshape(B, 3, L // 3, C) |
|
|
|
if self.cls_token: |
|
x = x[:, :, 1:] |
|
|
|
h = w = int((x.shape[2])**.5) |
|
assert h * w == x.shape[2] |
|
|
|
if p is None: |
|
x = x.reshape(shape=(B, 3, h, w, -1)) |
|
if not self.D_roll_out_input: |
|
x = rearrange( |
|
x, 'b n h w c->(b n) c h w' |
|
) |
|
else: |
|
x = rearrange( |
|
x, 'b n h w c->b c h (n w)' |
|
) |
|
else: |
|
x = x.reshape(shape=(B, 3, h, w, p, p, -1)) |
|
if self.D_roll_out_input: |
|
x = rearrange( |
|
x, 'b n h w p1 p2 c->b c (h p1) (n w p2)' |
|
) |
|
else: |
|
x = rearrange( |
|
x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)' |
|
) |
|
|
|
return x |
|
|
|
latent = unflatten_token( |
|
latent_from_vit) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latent = self.superresolution['conv_sr'](latent) |
|
if not self.D_roll_out_input: |
|
latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3) |
|
else: |
|
latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3) |
|
|
|
ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ret_dict |
|
|
|
|
|
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder( |
|
RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D |
|
): |
|
|
|
def __init__( |
|
self, |
|
vit_decoder: VisionTransformer, |
|
triplane_decoder: Triplane_fg_bg_plane, |
|
cls_token, |
|
normalize_feat=True, |
|
sr_ratio=2, |
|
use_fusion_blk=True, |
|
fusion_blk_depth=2, |
|
fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, |
|
channel_multiplier=4, |
|
**kwargs) -> None: |
|
super().__init__(vit_decoder, triplane_decoder, cls_token, |
|
normalize_feat, sr_ratio, use_fusion_blk, |
|
fusion_blk_depth, fusion_blk, channel_multiplier, |
|
patch_size=-1, |
|
**kwargs) |
|
|
|
|
|
for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: |
|
del blk.skip_linear |
|
|
|
@torch.inference_mode() |
|
def forward_points(self, |
|
planes, |
|
points: torch.Tensor, |
|
chunk_size: int = 2**16): |
|
|
|
|
|
N, P = points.shape[:2] |
|
if planes.ndim == 4: |
|
planes = planes.reshape( |
|
len(planes), |
|
3, |
|
-1, |
|
planes.shape[-2], |
|
planes.shape[-1]) |
|
|
|
|
|
outs = [] |
|
for i in trange(0, points.shape[1], chunk_size): |
|
chunk_points = points[:, i:i + chunk_size] |
|
|
|
|
|
|
|
chunk_out = self.triplane_decoder.renderer._run_model( |
|
planes=planes, |
|
decoder=self.triplane_decoder.decoder, |
|
sample_coordinates=chunk_points, |
|
sample_directions=torch.zeros_like(chunk_points), |
|
options=self.rendering_kwargs, |
|
) |
|
|
|
|
|
outs.append(chunk_out) |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
point_features = { |
|
k: torch.cat([out[k] for out in outs], dim=1) |
|
for k in outs[0].keys() |
|
} |
|
return point_features |
|
|
|
def triplane_decode_grid(self, |
|
vit_decode_out, |
|
grid_size, |
|
aabb: torch.Tensor = None, |
|
**kwargs): |
|
|
|
|
|
|
|
assert isinstance(vit_decode_out, dict) |
|
planes = vit_decode_out['latent_after_vit'] |
|
|
|
|
|
if aabb is None: |
|
if 'sampler_bbox_min' in self.rendering_kwargs: |
|
aabb = torch.tensor([ |
|
[self.rendering_kwargs['sampler_bbox_min']] * 3, |
|
[self.rendering_kwargs['sampler_bbox_max']] * 3, |
|
], |
|
device=planes.device, |
|
dtype=planes.dtype).unsqueeze(0).repeat( |
|
planes.shape[0], 1, 1) |
|
else: |
|
aabb = torch.tensor( |
|
[ |
|
[-self.rendering_kwargs['box_warp'] / 2] * 3, |
|
[self.rendering_kwargs['box_warp'] / 2] * 3, |
|
], |
|
device=planes.device, |
|
dtype=planes.dtype).unsqueeze(0).repeat( |
|
planes.shape[0], 1, 1) |
|
|
|
assert planes.shape[0] == aabb.shape[ |
|
0], "Batch size mismatch for planes and aabb" |
|
N = planes.shape[0] |
|
|
|
|
|
grid_points = [] |
|
for i in range(N): |
|
grid_points.append( |
|
torch.stack(torch.meshgrid( |
|
torch.linspace(aabb[i, 0, 0], |
|
aabb[i, 1, 0], |
|
grid_size, |
|
device=planes.device), |
|
torch.linspace(aabb[i, 0, 1], |
|
aabb[i, 1, 1], |
|
grid_size, |
|
device=planes.device), |
|
torch.linspace(aabb[i, 0, 2], |
|
aabb[i, 1, 2], |
|
grid_size, |
|
device=planes.device), |
|
indexing='ij', |
|
), |
|
dim=-1).reshape(-1, 3)) |
|
cube_grid = torch.stack(grid_points, dim=0).to(planes.device) |
|
|
|
|
|
features = self.forward_points(planes, cube_grid) |
|
|
|
|
|
features = { |
|
k: v.reshape(N, grid_size, grid_size, grid_size, -1) |
|
for k, v in features.items() |
|
} |
|
|
|
|
|
|
|
return features |
|
|
|
def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): |
|
|
|
pass |
|
|
|
def forward_vit_decoder(self, x, img_size=None): |
|
|
|
return self.vit_decoder(x) |
|
|
|
def vit_decode_backbone(self, latent, img_size): |
|
return super().vit_decode_backbone(latent, img_size) |
|
|
|
|
|
def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): |
|
return super().vit_decode_postprocess(latent_from_vit, ret_dict) |
|
|
|
def vae_reparameterization(self, latent, sample_posterior): |
|
return super().vae_reparameterization(latent, sample_posterior) |
|
|