import torch
import torch.nn as nn
from ldm.modules.attention import BasicTransformerBlock
from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
import torch.nn.functional as F
from ..attention import SelfAttention, FeedForward
from .convnext import convnext_tiny


class PositionNet(nn.Module):
    """Turn a batched input map into a sequence of grounding tokens.

    Pipeline:
    1. Resize the map to ``resize_input`` x ``resize_input``.
    2. Encode it with a ConvNeXt-tiny backbone.
    3. Flatten the backbone output into ``(resize_input // 32) ** 2`` feature tokens.
    4. Per sample, blend those tokens with a learned null token according to ``mask``.
    5. Add a learned positional embedding.
    6. Project each token to ``out_dim`` with a small MLP.
    """

    def __init__(self, resize_input=448, out_dim=768):
        """
        Args:
            resize_input: side length the input map is resized to before the
                backbone; must be divisible by ``down_factor`` (32).
            out_dim: feature size of each output token.
        """
        super().__init__()
        self.resize_input = resize_input
        self.down_factor = 32  # determined by the convnext backbone
        self.out_dim = out_dim
        assert self.resize_input % self.down_factor == 0

        # NOTE(review): pretrained=True presumably loads (possibly downloads)
        # weights at construction time — confirm in .convnext.
        self.convnext_tiny_backbone = convnext_tiny(pretrained=True)

        # One token per spatial cell of the backbone's downsampled output grid.
        self.num_tokens = (self.resize_input // self.down_factor) ** 2

        # Must equal the backbone's output channel count: forward() reshapes the
        # backbone output to (B, -1, num_tokens) and feeds the last dim to
        # self.linears, whose first layer expects this size.
        convnext_feature_dim = 768
        # Learned positional embedding, initialised BERT-style (normal, std=0.02).
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)
        )

        self.linears = nn.Sequential(
            nn.Linear(convnext_feature_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        # Learned feature substituted for the tokens of samples masked out by `mask`.
        self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))

    def forward(self, normal, mask):
        """Encode ``normal`` into ``num_tokens`` tokens of size ``out_dim``.

        Args:
            normal: batched input map. It is resized with ``F.interpolate``
                (int size) and passed to the ConvNeXt backbone, so presumably
                ``(B, C, H, W)`` — TODO confirm channel count against the backbone.
            mask: per-sample weight with ``B`` elements (e.g. shape ``(B,)``),
                or a single element applied to the whole batch. 1 keeps the
                backbone tokens, 0 replaces them with the learned null token,
                and fractional values blend linearly. Bool and integer masks
                are accepted.

        Returns:
            Tensor of shape ``(B, num_tokens, out_dim)``.
        """
        B = normal.shape[0]

        # Resize to a fixed square resolution so the backbone produces exactly
        # num_tokens spatial positions.
        normal = F.interpolate(normal, self.resize_input)
        normal_feature = self.convnext_tiny_backbone(normal)
        # Assumes the backbone returns (B, C, h, w):
        # -> (B, C, num_tokens) -> (B, num_tokens, C)
        objs = normal_feature.reshape(B, -1, self.num_tokens).permute(0, 2, 1)

        # Broadcast the null token over batch and token dims. expand() returns a
        # view, so unlike repeat() no copy is allocated.
        null_objs = self.null_feature.view(1, 1, -1).expand(B, self.num_tokens, -1)

        # Cast the mask to the features' dtype and device:
        # - a bool mask would otherwise fail on `1 - mask`;
        # - a mask with a wider dtype (e.g. float32 with a half-precision model)
        #   would type-promote the features and then mismatch the MLP weights.
        mask = mask.view(-1, 1, 1).to(device=objs.device, dtype=objs.dtype)
        objs = objs * mask + null_objs * (1 - mask)

        # Add the positional embedding, then project each token to out_dim.
        objs = objs + self.pos_embedding
        objs = self.linears(objs)

        assert objs.shape == torch.Size([B, self.num_tokens, self.out_dim])
        return objs