import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from torch import nn


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Patch1D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        stride: int = 2,
        padding: int = 0,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
            init.constant_(self.conv.weight, 0.0)
            with torch.no_grad():
                for i in range(len(self.conv.weight)):
                    self.conv.weight[i, i] = 1 / stride
            init.constant_(self.conv.bias, 0.0)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
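
# The function below is an informal usage sketch added for illustration; it is not
# part of the original module, and the tensor sizes are arbitrary assumptions chosen
# only to show the expected shapes of get_2d_sincos_pos_embed and Patch1D.
def _example_pos_embed_and_patch1d():
    # A 16x16 token grid with a 768-dim embedding gives one row per grid cell,
    # with sine features in the first half of the last axis and cosine in the second.
    pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pos.shape == (256, 768)

    # With use_conv=True, Patch1D is a stride-2 Conv1d whose weights are initialised
    # to mimic average pooling, so the sequence length is halved.
    down = Patch1D(channels=64, use_conv=True, stride=2)
    x = torch.randn(2, 64, 32)  # (B, C, L)
    assert down(x).shape == (2, 64, 16)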
class UnPatch1D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Upsampler(nn.Module):
    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()

        self.spatial_upsample_factor = spatial_upsample_factor
        self.temporal_upsample_factor = temporal_upsample_factor


class TemporalUpsampler3D(Upsampler):
    def __init__(self):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[2] > 1:
            first_frame, x = x[:, :, :1], x[:, :, 1:]
            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([first_frame, x], dim=2)
        return x


class CausalConv3d(nn.Conv3d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int]
        stride=1,  # : int | tuple[int, int, int] = 1
        padding=1,  # : int | tuple[int, int, int], TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int] = 1
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        t_pad = (t_ks - 1) * t_dilation  # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            raise NotImplementedError(f"Unsupported padding: {padding}")

        self.temporal_padding = t_pad

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        x = F.pad(
            x,
            pad=(0, 0, 0, 0, self.temporal_padding, 0),
            mode="replicate",  # TODO: check if this is necessary
        )
        return super().forward(x)
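
# The function below is an informal example added for illustration; it is not part
# of the original module. The shapes are assumptions chosen to show two properties:
# CausalConv3d pads only on the past side of the time axis (so the frame count is
# preserved and frame t never sees later frames), and TemporalUpsampler3D keeps the
# first frame while doubling the rest in time.
def _example_causal_conv_and_temporal_upsample():
    conv = CausalConv3d(in_channels=4, out_channels=8, kernel_size=3, stride=1, padding=None)
    x = torch.randn(1, 4, 5, 16, 16)  # (B, C, T, H, W)
    # (kernel - 1) frames of replicate padding on the left only: T stays 5.
    assert conv(x).shape == (1, 8, 5, 16, 16)

    up = TemporalUpsampler3D()
    y = up(torch.randn(1, 4, 5, 8, 8))
    # 1 kept frame + 4 frames upsampled 2x in time -> 9 frames.
    assert y.shape == (1, 4, 9, 8, 8)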
class PatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
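
# The function below is an informal shape walkthrough added for illustration; it is
# not part of the original module and the sizes are assumptions. It shows how
# PatchEmbed3D folds patchified frames into the batch dimension:
# (B, C, F, H, W) -> (B * F/time_patch_size, (H/patch_size) * (W/patch_size), embed_dim).
def _example_patch_embed_3d():
    embed = PatchEmbed3D(
        height=32, width=32, patch_size=16, time_patch_size=4, in_channels=3, embed_dim=64
    )
    video = torch.randn(2, 3, 8, 32, 32)  # (B, C, F, H, W)
    tokens = embed(video)
    # F: 8 -> 2 temporal patches, spatial grid 2x2 -> 4 tokens per frame group.
    assert tokens.shape == (2 * 2, 2 * 2, 64)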
class PatchEmbedF3D(nn.Module):
    """Fake 3D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.proj_t = Patch1D(embed_dim, True, stride=patch_size)
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        b, c, f, h, w = latent.size()

        # Patchify each frame spatially with the 2D conv, then patchify along time
        # with the 1D conv ("fake 3D" factorization).
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        latent = self.proj(latent)
        latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)

        latent = rearrange(latent, "b c f h w -> (b h w) c f")
        latent = self.proj_t(latent)
        latent = rearrange(
            latent, "(b h w) c f -> b c f h w", h=h // self.patch_size, w=w // self.patch_size
        )

        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)


class CasualPatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding with causal temporal padding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = CausalConv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
            padding=None,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
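
# The function below is an informal example added for illustration; it is not part
# of the original module and the sizes are assumptions. It shows that the causal
# variant produces ceil(F / time_patch_size) frame groups, because CausalConv3d pads
# the time axis on the past side only before the strided patchify.
def _example_casual_patch_embed_3d():
    embed = CasualPatchEmbed3D(
        height=32, width=32, patch_size=16, time_patch_size=4, in_channels=3, embed_dim=64
    )
    video = torch.randn(1, 3, 8, 32, 32)  # (B, C, F, H, W)
    tokens = embed(video)
    # F: 8 -> ceil(8 / 4) = 2 frame groups, each with (32/16) * (32/16) = 4 spatial tokens.
    assert tokens.shape == (1 * 2, 4, 64)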