"""Wrapper around diffusers' ``UNet2DModel`` for wide triplane-style inputs.

Codes are from: https://github.com/jaxony/unet-pytorch/blob/master/model.py
"""
import torch
import torch.nn as nn
from diffusers import UNet2DModel
import einops


class UNetPP(nn.Module):
    """Wrapper for UNet in diffusers.

    Builds a 7-level ``UNet2DModel`` over a 256 x 768 sample (presumably three
    256x256 planes laid side by side -- TODO confirm with callers) and always
    emits 32 output channels.

    If ``in_channels > 12``, the extra ``in_channels - 12`` channels are a
    learned constant plane that is concatenated onto any input that arrives
    with fewer channels than ``in_channels``.
    """

    def __init__(self, in_channels):
        """
        Args:
            in_channels (int): number of channels the wrapped UNet consumes.
                When greater than 12, a learned plane supplies the channels
                beyond the first 12 at forward time.
        """
        super(UNetPP, self).__init__()
        self.in_channels = in_channels
        self.unet = UNet2DModel(
            sample_size=[256, 256 * 3],
            in_channels=in_channels,
            out_channels=32,
            layers_per_block=2,
            block_out_channels=(64, 128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        # NOTE(review): raises if xformers is not installed -- this is a hard
        # dependency of the original code, kept as-is.
        self.unet.enable_xformers_memory_efficient_attention()
        if in_channels > 12:
            # Learned filler plane for the channels beyond the first 12;
            # broadcast over the batch dimension in forward().
            self.learned_plane = nn.Parameter(
                torch.zeros([1, in_channels - 12, 256, 256 * 3])
            )

    def forward(self, x, t=256):
        """Run the wrapped UNet.

        Args:
            x (Tensor): input of shape (B, C, 256, 768) with
                C in {12, in_channels} -- TODO confirm with callers.
            t (int or Tensor): timestep passed to the UNet. Defaults to 256.

        Returns:
            Tensor: the UNet's ``.sample`` output, shape (B, 32, 256, 768).
        """
        # BUG FIX: the original read ``self.learned_plane`` unconditionally,
        # which raised AttributeError whenever in_channels <= 12 (the
        # attribute is only created in __init__ for in_channels > 12).
        # Touch it only when the input actually needs the extra channels.
        if x.shape[1] < self.in_channels:
            plane = einops.repeat(
                self.learned_plane, '1 C H W -> B C H W', B=x.shape[0]
            ).to(x.device)
            x = torch.cat([x, plane], dim=1)
        return self.unet(x, t).sample