# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import collections

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils import logging
from einops import rearrange
from torch import nn

logger = logging.get_logger(__name__)


class CombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29

    Embeds a timestep and, optionally, resolution / aspect-ratio size
    conditions into one conditioning tensor.

    Args:
        embedding_dim: Output width of the timestep embedding.
        size_emb_dim: Output width of the embedding produced for each scalar
            size condition (stored as ``self.outdim``).
        use_additional_conditions: When True, ``forward`` adds the
            concatenated resolution and aspect-ratio embeddings to the
            timestep embedding.
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            # One sinusoidal projection shared by all size conditions; each
            # condition then has its own embedder.
            self.additional_condition_proj = Timesteps(
                num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
            )
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
        """Embed a size condition and broadcast it to ``batch_size`` rows.

        Args:
            size: 1D ``(n,)`` or 2D ``(n, dims)`` tensor of condition values.
                A 1D tensor is treated as ``dims == 1``.
            batch_size: Number of rows the result must have. If ``n`` differs,
                ``size`` is tiled along its first axis to reach it.
            embedder: Module mapping the 256-channel sinusoidal features of
                each scalar to ``self.outdim`` features.

        Returns:
            Tensor of shape ``(batch_size, dims * self.outdim)``.

        Raises:
            ValueError: If ``batch_size`` is not a positive multiple of ``n``.
        """
        if size.ndim == 1:
            size = size[:, None]

        if size.shape[0] != batch_size:
            given = size.shape[0]
            size = size.repeat(batch_size // given, 1)
            if size.shape[0] != batch_size:
                # Report the caller's original leading size, not the
                # partially tiled one, so the message is actionable.
                raise ValueError(f"`batch_size` should be a multiple of {given} but found {batch_size}.")

        current_batch_size, dims = size.shape[0], size.shape[1]
        # Embed every scalar independently, then regroup per sample.
        size = size.reshape(-1)
        # NOTE(review): the sinusoidal features are cast to `size`'s dtype, so an
        # integer `size` would truncate them — confirm callers pass float sizes.
        size_freq = self.additional_condition_proj(size).to(size.dtype)

        size_emb = embedder(size_freq)
        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
        return size_emb

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the conditioning embedding.

        Args:
            timestep: Timesteps fed to the sinusoidal projection.
            resolution: Resolution condition (only used when
                ``use_additional_conditions`` is True).
            aspect_ratio: Aspect-ratio condition (same as above).
            batch_size: Target batch size for the size conditions.
            hidden_dtype: dtype the timestep features are cast to before
                the embedding MLP.

        Returns:
            The timestep embedding, plus the concatenated size embeddings when
            additional conditions are enabled. For that sum to be valid, the
            concatenated width must equal ``embedding_dim``.
        """
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
            aspect_ratio = self.apply_condition(
                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
            )
            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning


class PatchEmbed2D(nn.Module):
    """2D Image to Patch Embedding

    Applies a per-frame 2D patchifying convolution (kernel = stride =
    ``patch_size``) to a video latent and returns the tokens of all frames as
    a single sequence.

    Only ``in_channels``, ``embed_dim``, ``patch_size``, ``layer_norm``,
    ``flatten`` and ``bias`` affect the computation in this class. The
    following are accepted for interface compatibility but are not used here:
    ``num_frames``, ``height``, ``width``, ``interpolation_scale`` and
    ``interpolation_scale_t``. ``patch_size_t`` and ``use_abs_pos`` are stored
    on the module but not read by ``forward``.
    """

    def __init__(
        self,
        num_frames=1,
        height=224,
        width=224,
        patch_size_t=1,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=(1, 1),
        interpolation_scale_t=1,
        use_abs_pos=False,
    ):
        super().__init__()
        self.use_abs_pos = use_abs_pos
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size_t = patch_size_t
        self.patch_size = patch_size

    def forward(self, latent):
        """Patchify a video latent into a token sequence.

        Args:
            latent: 5D tensor laid out as ``(b, c, t, h, w)``.

        Returns:
            Tensor of shape ``(b, t * n, embed_dim)``, where ``n`` is the number
            of 2D patches per frame.

        Raises:
            NotImplementedError: If the module was built with ``flatten=False``.
                The token layout produced here requires flattened patches.
        """
        b, _, _, _, _ = latent.shape
        if not self.flatten:
            # Without flattening the conv output stays 4D, so the token
            # rearrange below cannot apply. Fail clearly instead of deep inside einops.
            raise NotImplementedError(
                "PatchEmbed2D.forward requires flatten=True; unflattened output is not supported."
            )

        # Fold time into batch so every frame is patchified independently.
        latent = rearrange(latent, 'b c t h w -> (b t) c h w')
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BT C H W -> BT N C
        if self.layer_norm:
            latent = self.norm(latent)

        # Unfold time and concatenate each frame's tokens along the sequence axis.
        return rearrange(latent, '(b t) n c -> b (t n) c', b=b)