# spad/unet/mv_unet.py
import math
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat
from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
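
# Example (sketch): conv_nd(2, 10, 320, 3, padding=1) builds the same module as
# nn.Conv2d(10, 320, kernel_size=3, padding=1); post_init below uses the dims=2
# branch when it rebuilds proj_in with 6 extra plucker channels.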
# sinusoidal timestep embedding used to build the UNet time conditioning
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
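
# Example (sketch): for a flattened batch of 8 timesteps and model_channels=320,
#   emb = timestep_embedding(torch.randint(0, 1000, (8,)), 320)
# returns an [8, 320] tensor (cosines in the first half, sines in the second),
# which forward() below maps through self.time_embed.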
class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
""" Modified UnetModel to support simultaneous denoising of many views. """
def __init__(self, image_size, in_channels, model_channels, out_channels,
num_res_blocks, attention_resolutions, **kwargs):
super().__init__(image_size=image_size,
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
**kwargs)
self.proj_in = nn.Conv2d(
in_channels=kwargs.get("in_channels", 4),
out_channels=kwargs.get("model_channels", 320),
kernel_size=3,
stride=1,
padding=1
)
self.post_init()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# load the config
config = cls.load_config(pretrained_model_name_or_path, **kwargs)
# pass the config parameters to the __init__ method
return cls(**config)
def post_init(self):
assert getattr(self, "post_intialized", False) is False, "Already modified!"
# Inflate input conv block to attach plucker coordinates
conv_block = self.proj_in
conv_params = {
k: getattr(conv_block, k)
for k in [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
]
}
conv_params["in_channels"] += 6
conv_params["dims"] = 2
conv_params["device"] = conv_block.weight.device
# Copy original weights for input conv block
inflated_proj_in = conv_nd(**conv_params)
inp_weight = conv_block.weight.data
feat_shape = inp_weight.shape
# Initialize new weights for plucker coordinates as zeros
feat_weight = torch.zeros(
(feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
)
# Assemble new weights and bias
inflated_proj_in.weight.data.copy_(
torch.cat([inp_weight, feat_weight], dim=1)
)
inflated_proj_in.bias.data.copy_(conv_block.bias.data)
self.proj_in = inflated_proj_in
self.post_intialized = True
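
    # Sketch of the inflation above (assuming the default 4 latent channels and
    # model_channels=320): proj_in.weight goes from [320, 4, 3, 3] to [320, 10, 3, 3].
    # The first 4 input channels keep the pretrained weights and the 6 plucker
    # channels start at zero, so the inflated conv initially matches the original.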
    def encode(self, h, emb, context, blocks):
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            hs.append(h)
        return hs
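
    # Shape sketch for encode/decode (hypothetical sizes): with n_objects=2,
    # n_views=4 and 64x64 latents, h enters as [2, 4, C, 64, 64]; resnet and conv
    # layers see it flattened to [8, C, 64, 64], while SpatialTransformer layers
    # receive the full 5-D tensor so attention can mix information across views.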
    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.output_blocks):
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)
        # process last layer
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            if self.predict_codebook_ids:
                # not used in vae
                h = self.id_predictor(h)
            else:
                h = self.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        objaverse batch:
            # img (x): [n_objects, n_views, 4, 64, 64]
            # timesteps (timesteps): [n_objects, n_views]
            # txt (context[0]): [n_objects, n_views, max_seq_len, 768]
            # cam (context[1]): [n_objects, n_views, 1280]
        laion batch:
            # img (x): [batch_size, 1, 4, 64, 64]
            # timesteps (timesteps): [batch_size, 1, 1]
            # txt (context[0]): [batch_size, 1, max_seq_len, 768]
            # cam (context[1]): [batch_size, 1, 1280] * 0.0
        :return: an [n_objects, n_views, 4, 64, 64] Tensor of denoised latents.
        """
        n_objects, n_views = x.shape[:2]
        # timestep embedding
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)  # [n_objects, n_views, time_embed_dim]
        # extract txt and cam embedding (absolute) from context
        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected 2 or 3 context entries, got {len(context)}")
        # extract plucker embedding from x
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
        # combine timestep and camera embedding (resnet)
        time_cam = time  # add + cam later
        del time, cam
        # encode
        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.input_blocks)
        # middle block
        h = self.encode(hs[-1], time_cam, txt, [self.middle_block])[0]
        # decode
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h
if __name__ == "__main__":
    model_args = {
        "image_size": 32,  # unused
        "in_channels": 4,
        "out_channels": 4,
        "model_channels": 320,
        "attention_resolutions": [4, 2, 1],
        "num_res_blocks": 2,
        "channel_mult": [1, 2, 4, 4],
        "num_heads": 8,
        "use_spatial_transformer": True,
        "transformer_depth": 1,
        "context_dim": 768,
        "use_checkpoint": False,
        "legacy": False,
    }
    # manyviews unet
    model = SPADUnetModel(**model_args)
    model.eval()

    n_objects = 2
    n_views = 3
    # img (x): [n_objects, n_views, 4 latent + 6 plucker, 32, 32]
    # txt (c): [n_objects, n_views, max_seq_len, 768]
    # cam (T): [n_objects, n_views, 1280]
    x = torch.randn(n_objects, n_views, 10, 32, 32)
    timesteps = torch.randint(0, 1000, (n_objects, n_views)).long()
    context = [
        torch.randn(n_objects, n_views, 77, 768),
        torch.randn(n_objects, n_views, 1280),
        torch.ones(n_objects, n_views * 32 * 32, n_views * 32 * 32, dtype=torch.bool),
    ]
    context[-1][0] = False
    out = model(x, timesteps=timesteps, context=context)
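    # Sanity check (assumed shapes): the denoised output should drop the 6 plucker
    # channels and come back as [n_objects, n_views, 4, 32, 32].
    print("output:", tuple(out.shape))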