import torch
import torch.nn as nn
import torch.nn.functional as F
from dotmap import DotMap

from salad.model_components.simple_module import TimePointWiseEncoder, TimestepEmbedder
from salad.model_components.transformer import (
    PositionalEncoding,
    TimeTransformerDecoder,
    TimeTransformerEncoder,
)


class UnCondDiffNetwork(nn.Module):
    def __init__(self, input_dim, residual, **kwargs):
        """
        Transformer Encoder.
        """
        super().__init__()
        self.input_dim = input_dim
        self.residual = residual
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
        self._build_model()

    def _build_model(self):
        self.act = F.leaky_relu
        if self.hparams.get("use_timestep_embedder"):
            self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
            dim_ctx = self.hparams.timestep_embedder_dim
        else:
            # Fallback time conditioning: raw [beta, sin(beta), cos(beta)].
            dim_ctx = 3

        """ Encoder part """
        enc_dim = self.hparams.embedding_dim
        self.embedding = nn.Linear(self.hparams.input_dim, enc_dim)
        if not self.hparams.get("encoder_type"):
            self.encoder = TimeTransformerEncoder(
                enc_dim,
                dim_ctx=dim_ctx,
                num_heads=self.hparams.num_heads if self.hparams.get("num_heads") else 4,
                use_time=True,
                num_layers=self.hparams.enc_num_layers,
                last_fc=True,
                last_fc_dim_out=self.hparams.input_dim,
            )
        elif self.hparams.encoder_type == "transformer":
            self.encoder = TimeTransformerEncoder(
                enc_dim,
                dim_ctx=dim_ctx,
                num_heads=self.hparams.num_heads if self.hparams.get("num_heads") else 4,
                use_time=True,
                num_layers=self.hparams.enc_num_layers,
                last_fc=True,
                last_fc_dim_out=self.hparams.input_dim,
                dropout=self.hparams.get("attn_dropout", 0.0),
            )
        else:
            raise ValueError(f"Unknown encoder_type: {self.hparams.encoder_type}")

    def forward(self, x, beta):
        """
        Input:
            x: [B,G,D] latent
            beta: [B] diffusion timestep
        Output:
            eta: [B,G,D] predicted noise
        """
        B, G = x.shape[:2]
        if self.hparams.get("use_timestep_embedder"):
            time_emb = self.time_embedder(beta).unsqueeze(1)  # [B,1,dim_ctx]
        else:
            beta = beta.view(B, 1, 1)
            time_emb = torch.cat(
                [beta, torch.sin(beta), torch.cos(beta)], dim=-1
            )  # [B,1,3]
        ctx = time_emb

        x_emb = self.embedding(x)
        out = self.encoder(x_emb, ctx=ctx)
        if self.hparams.residual:
            out = out + x
        return out
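# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test for
# UnCondDiffNetwork showing the hparams it reads and the tensor shapes it
# expects. All hyperparameter values are illustrative assumptions, not the
# SALAD training configuration; it relies only on constructor arguments the
# class itself forwards to TimeTransformerEncoder.
# ---------------------------------------------------------------------------
def _demo_uncond_diff_network():
    net = UnCondDiffNetwork(
        input_dim=512,      # D: per-token latent width (assumed)
        residual=True,      # add the input back onto the prediction
        embedding_dim=512,  # transformer width (assumed)
        enc_num_layers=6,   # encoder depth (assumed)
    )
    x = torch.randn(2, 16, 512)  # [B, G, D]
    beta = torch.rand(2)         # [B] diffusion timesteps
    eta = net(x, beta)           # [B, G, D], same shape as x
    assert eta.shape == x.shape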
""" super().__init__() self.input_dim = input_dim self.residual = residual self.__dict__.update(kwargs) self.hparams = DotMap(self.__dict__) self._build_model() def _build_model(self): self.act = F.leaky_relu if self.hparams.get("use_timestep_embedder"): self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim) dim_ctx = self.hparams.timestep_embedder_dim else: dim_ctx = 3 """ Encoder part """ enc_dim = self.hparams.context_embedding_dim self.context_embedding = nn.Linear(self.hparams.context_dim, enc_dim) if self.hparams.encoder_type == "transformer": self.encoder = TimeTransformerEncoder( enc_dim, 3, num_heads=4, use_time=self.hparams.encoder_use_time, num_layers=self.hparams.enc_num_layers if self.hparams.get("enc_num_layers") else 3, last_fc=False, ) elif self.hparams.encoder_type == "pointwise": self.encoder = TimePointWiseEncoder( enc_dim, dim_ctx=None, use_time=self.hparams.encoder_use_time, num_layers=self.hparams.enc_num_layers, ) else: raise ValueError """ Decoder part """ dec_dim = self.hparams.embedding_dim input_dim = self.hparams.input_dim self.query_embedding = nn.Linear(self.hparams.input_dim, dec_dim) if self.hparams.decoder_type == "transformer_decoder": self.decoder = TimeTransformerDecoder( dec_dim, enc_dim, dim_ctx=dim_ctx, num_heads=4, last_fc=True, last_fc_dim_out=input_dim, num_layers=self.hparams.dec_num_layers if self.hparams.get("dec_num_layers") else 3, ) elif self.hparams.decoder_type == "transformer_encoder": self.decoder = TimeTransformerEncoder( dec_dim, dim_ctx=enc_dim + dim_ctx, num_heads=4, last_fc=True, last_fc_dim_out=input_dim, num_layers=self.hparams.dec_num_layers if self.hparams.get("dec_num_layers") else 3, ) else: raise ValueError def forward(self, x, beta, context): """ Input: x: [B,G,D] intrinsic beta: B context: [B,G,D2] or [B, D2] condition Output: eta: [B,G,D] """ # print(f"x: {x.shape} context: {context.shape} beta: {beta.shape}") B, G = x.shape[:2] if self.hparams.get("use_timestep_embedder"): time_emb = self.time_embedder(beta).unsqueeze(1) else: beta = beta.view(B, 1, 1) time_emb = torch.cat( [beta, torch.sin(beta), torch.cos(beta)], dim=-1 ) # [B,1,3] ctx = time_emb """ Encoding """ cout = self.context_embedding(context) cout = self.encoder(cout, ctx=ctx if self.hparams.encoder_use_time else None) if cout.ndim == 2: cout = cout.unsqueeze(1).expand(-1, G, -1) """ Decoding """ out = self.query_embedding(x) if self.hparams.get("use_pos_encoding"): out = self.pos_encoding(out) if self.hparams.decoder_type == "transformer_encoder": try: ctx = ctx.expand(-1, G, -1) if cout.ndim == 2: cout = cout.unsqueeze(1) cout = cout.expand(-1, G, -1) ctx = torch.cat([ctx, cout], -1) except Exception as e: print(e, G, ctx.shape, cout.shape) out = self.decoder(out, ctx=ctx) else: out = self.decoder(out, cout, ctx=ctx) # if hasattr(self, "last_fc"): # out = self.last_fc(out) if self.hparams.residual: out = out + x return out