import collections

import torch
from einops import rearrange
from torch import nn

from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils import logging

logger = logging.get_logger(__name__)


class CombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
        if size.ndim == 1:
            size = size[:, None]

        if size.shape[0] != batch_size:
            # Broadcast a condition shared across the batch (e.g. a single resolution for all samples).
            size = size.repeat(batch_size // size.shape[0], 1)
            if size.shape[0] != batch_size:
                raise ValueError(
                    f"Cannot broadcast `size` with batch dimension {size.shape[0]} to `batch_size` {batch_size}."
                )

        current_batch_size, dims = size.shape[0], size.shape[1]
        size = size.reshape(-1)
        size_freq = self.additional_condition_proj(size).to(size.dtype)

        size_emb = embedder(size_freq)
        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
        return size_emb

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (batch_size, embedding_dim)

        if self.use_additional_conditions:
            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
            aspect_ratio = self.apply_condition(
                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
            )
            # The summed tensors must agree in width: embedding_dim must equal
            # (resolution dims + aspect-ratio dims) * size_emb_dim.
            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
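

# A minimal usage sketch (illustrative only; the dimensions below are assumptions chosen so
# that embedding_dim == 3 * size_emb_dim, which the addition in `forward` requires when
# `use_additional_conditions=True`: resolution contributes 2 * size_emb_dim features and
# aspect ratio 1 * size_emb_dim):
#
#     emb = CombinedTimestepSizeEmbeddings(embedding_dim=1152, size_emb_dim=384, use_additional_conditions=True)
#     timestep = torch.randint(0, 1000, (2,))
#     resolution = torch.tensor([[512.0, 512.0]])   # shape (1, 2), broadcast to batch_size=2
#     aspect_ratio = torch.tensor([[1.0]])          # shape (1, 1)
#     cond = emb(timestep, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
#     assert cond.shape == (2, 1152)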


class PatchEmbed2D(nn.Module):
    """2D Image to Patch Embedding, applied frame-wise to a video latent of shape (b, c, t, h, w)."""

    def __init__(
        self,
        num_frames=1,
        height=224,
        width=224,
        patch_size_t=1,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=(1, 1),
        interpolation_scale_t=1,
        use_abs_pos=False,
    ):
        super().__init__()
        self.use_abs_pos = use_abs_pos
        self.flatten = flatten
        self.layer_norm = layer_norm

        # Frame-wise patchify: a strided 2D convolution maps each non-overlapping
        # patch_size x patch_size patch to a single embed_dim-dimensional token.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size_t = patch_size_t
        self.patch_size = patch_size

    def forward(self, latent):
        # latent: (b, c, t, h, w). Fold the time axis into the batch so the 2D conv
        # patchifies every frame independently.
        b, _, _, _, _ = latent.shape

        latent = rearrange(latent, "b c t h w -> (b t) c h w")

        latent = self.proj(latent)
        if self.flatten:
            # (b t, embed_dim, h', w') -> (b t, h' * w', embed_dim)
            latent = latent.flatten(2).transpose(1, 2)
        if self.layer_norm:
            latent = self.norm(latent)

        # Restore the batch axis and merge time into the token axis: (b, t * n, c).
        latent = rearrange(latent, "(b t) n c -> b (t n) c", b=b)

        return latent
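

# A minimal shape-check sketch (illustrative only; the latent sizes are assumptions):
#
#     patch_embed = PatchEmbed2D(patch_size=2, in_channels=4, embed_dim=768)
#     latent = torch.randn(1, 4, 8, 32, 32)   # (b, c, t, h, w)
#     tokens = patch_embed(latent)
#     assert tokens.shape == (1, 8 * (32 // 2) ** 2, 768)   # (b, t * h' * w', embed_dim)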