# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn


class ModLN(nn.Module):
    """
    Modulation with adaLN.

    References:
        DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """
    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )

    @staticmethod
    def modulate(x, shift, scale):
        # x: [N, L, D]
        # shift, scale: [N, D]
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, cond):
        shift, scale = self.mlp(cond).chunk(2, dim=-1)  # [N, D]
        return self.modulate(self.norm(x), shift, scale)  # [N, L, D]
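

# Usage sketch for ModLN (illustrative only; the dimensions below are arbitrary
# assumptions, not values used elsewhere in this file):
#   >>> mod_ln = ModLN(inner_dim=64, mod_dim=32, eps=1e-6)
#   >>> x = torch.randn(2, 10, 64)     # token sequence [N, L, D]
#   >>> cond = torch.randn(2, 32)      # modulation vector [N, D_mod]
#   >>> mod_ln(x, cond).shape          # torch.Size([2, 10, 64]): LayerNorm, then scale/shift predicted from cond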


class ConditionModulationBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    """
    # use attention from torch.nn.MultiheadAttention
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = ModLN(inner_dim, mod_dim, eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = ModLN(inner_dim, mod_dim, eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = ModLN(inner_dim, mod_dim, eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond, mod):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        # mod: [N, D_mod]
        x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
        before_sa = self.norm2(x, mod)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x, mod))
        return x
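

# Usage sketch for ConditionModulationBlock (illustrative only; the sizes are arbitrary assumptions):
#   >>> block = ConditionModulationBlock(inner_dim=64, cond_dim=48, mod_dim=32, num_heads=4, eps=1e-6)
#   >>> x = torch.randn(2, 10, 64)     # triplane tokens [N, L, D]
#   >>> cond = torch.randn(2, 7, 48)   # image features [N, L_cond, D_cond]
#   >>> mod = torch.randn(2, 32)       # camera embedding [N, D_mod]
#   >>> block(x, cond, mod).shape      # torch.Size([2, 10, 64]): residual cross-attn, self-attn, then MLP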


class TriplaneTransformer(nn.Module):
    """
    Transformer with condition and modulation that generates a triplane representation.

    Reference:
        Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(self, inner_dim: int, image_feat_dim: int, camera_embed_dim: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 num_layers: int, num_heads: int,
                 eps: float = 1e-6):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
        self.layers = nn.ModuleList([
            ConditionModulationBlock(
                inner_dim=inner_dim, cond_dim=image_feat_dim, mod_dim=camera_embed_dim, num_heads=num_heads, eps=eps)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats, camera_embeddings):
        # image_feats: [N, L_cond, D_cond]
        # camera_embeddings: [N, D_mod]

        assert image_feats.shape[0] == camera_embeddings.shape[0], \
            f"Mismatched batch size: {image_feats.shape[0]} vs {camera_embeddings.shape[0]}"

        N = image_feats.shape[0]
        H = W = self.triplane_low_res
        L = 3 * H * W

        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        for layer in self.layers:
            x = layer(x, image_feats, camera_embeddings)
        x = self.norm(x)

        # separate each plane and apply deconv
        x = x.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.deconv(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()

        assert self.triplane_high_res == x.shape[-2], \
            f"Output triplane resolution does not match with expected: {x.shape[-2]} vs {self.triplane_high_res}"
        assert self.triplane_dim == x.shape[-3], \
            f"Output triplane dimension does not match with expected: {x.shape[-3]} vs {self.triplane_dim}"

        return x
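

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the sizes below are arbitrary assumptions
    # chosen to run quickly on CPU; they are not the configuration used for training.
    transformer = TriplaneTransformer(
        inner_dim=64, image_feat_dim=48, camera_embed_dim=32,
        triplane_low_res=8, triplane_high_res=16, triplane_dim=20,
        num_layers=2, num_heads=4,
    )
    image_feats = torch.randn(2, 49, 48)     # [N, L_cond, D_cond]
    camera_embeddings = torch.randn(2, 32)   # [N, D_mod]
    with torch.no_grad():
        planes = transformer(image_feats, camera_embeddings)
    # Each of the 3 planes is upsampled 2x by the deconv: 8 -> 16.
    print(planes.shape)  # expected: torch.Size([2, 3, 20, 16, 16])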