import math

import torch
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from einops import rearrange, repeat

from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import TimestepBlock, UNetModel


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
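
# For instance, conv_nd(2, 4, 320, 3, padding=1) returns
# nn.Conv2d(4, 320, kernel_size=3, padding=1).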


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal projection and simply
                        repeat the raw timesteps across the embedding dim.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
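
# For example (shapes only):
#     emb = timestep_embedding(torch.tensor([0.0, 500.0]), dim=320)
#     emb.shape  # torch.Size([2, 320]); first half cosines, second half sines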


class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
    """Modified UNetModel to support simultaneous denoising of many views."""

    def __init__(self, image_size, in_channels, model_channels, out_channels,
                 num_res_blocks, attention_resolutions, **kwargs):
        super().__init__(image_size=image_size,
                         in_channels=in_channels,
                         model_channels=model_channels,
                         out_channels=out_channels,
                         num_res_blocks=num_res_blocks,
                         attention_resolutions=attention_resolutions,
                         **kwargs)

        # `in_channels` and `model_channels` are explicit parameters and never
        # land in `kwargs`, so use them directly rather than `kwargs.get`.
        self.proj_in = nn.Conv2d(
            in_channels=in_channels,
            out_channels=model_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Only the saved config is loaded here; weights must be restored
        # separately (e.g. via `load_state_dict`).
        config = cls.load_config(pretrained_model_name_or_path, **kwargs)
        return cls(**config)
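
    # Hypothetical usage (the path below is illustrative):
    #     unet = SPADUnetModel.from_pretrained("./spad_unet")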

    def post_init(self):
        assert not getattr(self, "post_initialized", False), "Already modified!"

        # Inflate `proj_in` with 6 extra input channels for the per-pixel
        # Plucker ray embeddings concatenated to the 4 latent channels.
        conv_block = self.proj_in
        conv_params = {
            k: getattr(conv_block, k)
            for k in [
                "in_channels",
                "out_channels",
                "kernel_size",
                "stride",
                "padding",
            ]
        }
        conv_params["in_channels"] += 6
        conv_params["dims"] = 2
        conv_params["device"] = conv_block.weight.device

        inflated_proj_in = conv_nd(**conv_params)
        inp_weight = conv_block.weight.data
        feat_shape = inp_weight.shape

        # Zero-initialize the new channels so the inflated layer initially
        # reproduces the pretrained behaviour on the original channels.
        feat_weight = torch.zeros(
            (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
        )

        inflated_proj_in.weight.data.copy_(
            torch.cat([inp_weight, feat_weight], dim=1)
        )
        inflated_proj_in.bias.data.copy_(conv_block.bias.data)
        self.proj_in = inflated_proj_in
        self.post_initialized = True
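
    # Illustrative effect of post_init, assuming Stable-Diffusion-style defaults
    # (4 latent channels, model_channels=320; not fixed by this file):
    #     before: self.proj_in.weight.shape == (320, 4, 3, 3)
    #     after:  self.proj_in.weight.shape == (320, 10, 3, 3)
    # with the 6 new input channels zero-initialized.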

    def encode(self, h, emb, context, blocks):
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    # Multi-view attention consumes the full [n, v, ...] tensor.
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # Fold views into the batch dimension for per-view layers.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")

                    h = layer(h, emb)

                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    h = layer(h)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)

            if h.isnan().any():
                raise RuntimeError(f"NaN activations in block {i} of encode()")

            hs.append(h)
        return hs

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.output_blocks):
            # Concatenate the matching encoder skip connection along channels
            # (dim=2, since the layout here is [n, v, c, h, w]).
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # Fold views into the batch dimension for per-view layers.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")

                    h = layer(h, emb)

                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    h = layer(h)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)

        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            if self.predict_codebook_ids:
                h = self.id_predictor(h)
            else:
                h = self.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)

        ho.append(h)
        return ho if return_outputs else h

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        objaverse batch:
            # img (x): [n_objects, n_views, 4, 64, 64]
            # timesteps (timesteps): [n_objects, n_views]
            # txt (context[0]): [n_objects, n_views, max_seq_len, 768]
            # cam (context[1]): [n_objects, n_views, 1280]

        laion batch:
            # img (x): [batch_size, 1, 4, 64, 64]
            # timesteps (timesteps): [batch_size, 1, 1]
            # txt (context[0]): [batch_size, 1, max_seq_len, 768]
            # cam (context[1]): [batch_size, 1, 1280] * 0.0

        :return: an [n_objects, n_views, 4, 64, 64] Tensor of denoised latents.
        """
        n_objects, n_views = x.shape[:2]

        # Embed the per-view timesteps, then restore the [n, v, d] layout.
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)

        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected context of length 2 or 3, got {len(context)}")

        # If Plucker ray embeddings are concatenated to the latents, split them
        # off and pass them along with the text conditioning.
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)

        # Fuse the per-view camera embedding into the time embedding
        # (both are [n, v, 1280]).
        time_cam = time + cam
        del time, cam

        h = x.type(self.dtype)

        hs = self.encode(h, time_cam, txt, self.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h
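
# A minimal, shape-only usage sketch (the config values below are illustrative
# assumptions, not a released checkpoint's; UNetModel may require additional
# kwargs such as channel_mult or context_dim):
#
#     unet = SPADUnetModel(
#         image_size=64, in_channels=4, model_channels=320, out_channels=4,
#         num_res_blocks=2, attention_resolutions=[4, 2, 1],
#     )
#     x = torch.randn(2, 4, 4 + 6, 64, 64)        # latents + Plucker channels
#     t = torch.randint(0, 1000, (2, 4)).float()  # per-view timesteps
#     txt = torch.randn(2, 4, 77, 768)            # per-view text features
#     cam = torch.randn(2, 4, 1280)               # per-view camera embeddings
#     eps = unet(x, timesteps=t, context=[txt, cam])  # -> [2, 4, 4, 64, 64]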