import math

import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat

from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, repeat each timestep `dim` times instead of
                        computing sinusoidal features.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding
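

# Shape sketch for `timestep_embedding` (illustrative only; the values below are
# example inputs, not part of the model). With `dim=320` (the default
# `model_channels`), a batch of N timesteps maps to an [N, 320] tensor whose first
# 160 columns are cosines and last 160 are sines of the timestep scaled by
# geometrically spaced frequencies:
#
#   ts = torch.tensor([0, 250, 999])                              # [N] diffusion steps
#   emb = timestep_embedding(ts, dim=320)                         # [N, 320] sinusoids
#   emb_rep = timestep_embedding(ts, dim=320, repeat_only=True)   # [N, 320] raw repeats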


class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
    """
    Modified UNetModel that denoises several views of the same object simultaneously.
    """

    def __init__(self, image_size, in_channels, model_channels, out_channels,
                 num_res_blocks, attention_resolutions, **kwargs):
        super().__init__(
            image_size=image_size,
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            **kwargs,
        )
        # Input projection; post_init() inflates it with 6 extra channels
        # for plucker coordinates.
        self.proj_in = nn.Conv2d(
            in_channels=in_channels,
            out_channels=model_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # load the config and pass its parameters to __init__
        config = cls.load_config(pretrained_model_name_or_path, **kwargs)
        return cls(**config)

    def post_init(self):
        assert getattr(self, "post_initialized", False) is False, "Already modified!"

        # Inflate the input conv block to accept 6 extra plucker-coordinate channels
        conv_block = self.proj_in
        conv_params = {
            k: getattr(conv_block, k)
            for k in ["in_channels", "out_channels", "kernel_size", "stride", "padding"]
        }
        conv_params["in_channels"] += 6
        conv_params["dims"] = 2
        conv_params["device"] = conv_block.weight.device
        inflated_proj_in = conv_nd(**conv_params)

        # Copy the original weights into the first channels and zero-initialize the
        # new plucker channels, so the inflated conv initially reproduces the
        # original projection.
        inp_weight = conv_block.weight.data
        feat_shape = inp_weight.shape
        feat_weight = torch.zeros(
            (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
        )
        inflated_proj_in.weight.data.copy_(torch.cat([inp_weight, feat_weight], dim=1))
        inflated_proj_in.bias.data.copy_(conv_block.bias.data)

        self.proj_in = inflated_proj_in
        self.post_initialized = True

    def encode(self, h, emb, context, blocks):
        hs = []
        n_objects, n_views = h.shape[:2]
        for block in blocks:
            for layer in block:
                if isinstance(layer, SpatialTransformer):
                    # multi-view attention keeps the view dimension
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            if h.isnan().any():
                # debugging hook: drop into the debugger if activations blow up
                breakpoint()
            hs.append(h)
        return hs

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.output_blocks):
            # concatenate the matching skip connection along the channel dim
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for layer in block:
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)

        # process last layer
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            if self.predict_codebook_ids:
                # not used in vae
                h = self.id_predictor(h)
            else:
                h = self.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        objaverse batch:
            img       (x):          [n_objects, n_views, 4, 64, 64]
            timesteps (timesteps):  [n_objects, n_views]
            txt       (context[0]): [n_objects, n_views, max_seq_len, 768]
            cam       (context[1]): [n_objects, n_views, 1280]
        laion batch:
            img       (x):          [batch_size, 1, 4, 64, 64]
            timesteps (timesteps):  [batch_size, 1]
            txt       (context[0]): [batch_size, 1, max_seq_len, 768]
            cam       (context[1]): [batch_size, 1, 1280] * 0.0
        :return: an [n_objects, n_views, 4, 64, 64] tensor
        """
        n_objects, n_views = x.shape[:2]

        # timestep embedding, computed per view
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)

        # extract txt and cam embedding (absolute) from context
        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError("context must be (txt, cam) or (txt, cam, epi_mask)")

        # extract plucker embedding from x (anything beyond the 4 latent channels)
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)

        # combine timestep and camera embedding (resnet)
        time_cam = time  # add + cam later
        del time, cam

        # encode
        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.input_blocks)

        # middle block
        h = self.encode(hs[-1], time_cam, txt, [self.middle_block])[0]

        # decode
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h
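

# Sanity-check sketch for the weight inflation done in `post_init` (standalone,
# hypothetical tensors; not part of the model). Because the 6 extra input channels
# are zero-initialized, the inflated projection matches the original conv on the
# first 4 channels regardless of what is placed in the plucker slots:
#
#   old = nn.Conv2d(4, 320, kernel_size=3, padding=1)
#   new = nn.Conv2d(10, 320, kernel_size=3, padding=1)
#   new.weight.data.copy_(torch.cat([old.weight.data, torch.zeros(320, 6, 3, 3)], dim=1))
#   new.bias.data.copy_(old.bias.data)
#   z = torch.randn(1, 4, 64, 64)
#   plucker = torch.randn(1, 6, 64, 64)
#   assert torch.allclose(old(z), new(torch.cat([z, plucker], dim=1)), atol=1e-5)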


# if __name__ == "__main__":
#     model_args = {
#         "image_size": 32,  # unused
#         "in_channels": 4,
#         "out_channels": 4,
#         "model_channels": 320,
#         "attention_resolutions": [4, 2, 1],
#         "num_res_blocks": 2,
#         "channel_mult": [1, 2, 4, 4],
#         "num_heads": 8,
#         "use_spatial_transformer": True,
#         "transformer_depth": 1,
#         "context_dim": 768,
#         "use_checkpoint": False,
#         "legacy": False,
#     }
#
#     # multi-view unet
#     model = SPADUnetModel(**model_args)
#     model.eval()
#
#     n_objects = 2
#     n_views = 3
#     # img (z):  [n_objects, n_views, 4 latent + 6 plucker, 32, 32]
#     # txt (c):  [n_objects, n_views, max_seq_len, 768]
#     # cam (T):  [n_objects, n_views, 1280]
#     x = torch.randn(n_objects, n_views, 10, 32, 32)
#     timesteps = torch.randint(0, 1000, (n_objects, n_views)).long()
#     context = [
#         torch.randn(n_objects, n_views, 77, 768),
#         torch.randn(n_objects, n_views, 1280),
#         torch.ones(n_objects, n_views * 32 * 32, n_views * 32 * 32, dtype=torch.bool),
#     ]
#     context[-1][0] = False
#     out = model(x, timesteps=timesteps, context=context)
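

# Hedged sketch of how the 6 plucker-coordinate channels concatenated onto `x` could
# be built (one common convention: per-pixel ray direction `d` plus moment `o x d`;
# the exact convention and normalization used upstream may differ, and
# `ray_directions` / `camera_origins` below are hypothetical inputs, not part of
# this module):
#
#   def make_plucker(ray_directions, camera_origins):
#       # ray_directions: [n_objects, n_views, 3, H, W]; camera_origins: [n_objects, n_views, 3]
#       d = torch.nn.functional.normalize(ray_directions, dim=2)
#       o = camera_origins[..., None, None].expand_as(d)
#       moment = torch.cross(o, d, dim=2)
#       return torch.cat([d, moment], dim=2)  # [n_objects, n_views, 6, H, W]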