from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.embeddings import (CombinedTimestepLabelEmbeddings,
                                          TimestepEmbedding, Timesteps)
from torch import nn


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


class FP32LayerNorm(nn.LayerNorm):
    # Layer norm that always computes in float32 and casts the result back to the input dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        if hasattr(self, 'weight') and self.weight is not None:
            return F.layer_norm(
                inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
            ).to(origin_dtype)
        else:
            return F.layer_norm(
                inputs.float(), self.normalized_shape, None, None, self.eps
            ).to(origin_dtype)
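

# Illustrative usage sketch (not part of the original module): FP32LayerNorm runs the
# normalization in float32 and casts the result back to the input dtype, which keeps
# half-precision activations numerically stable. Shapes below are arbitrary assumptions.
def _example_fp32_layer_norm():
    norm = FP32LayerNorm(64).half()
    x = torch.randn(2, 8, 64, dtype=torch.float16)
    out = norm(x)
    assert out.dtype == torch.float16  # computed in fp32, returned in the input dtype
    return out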


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()
        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
            self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
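

# Illustrative usage sketch (assumed shapes, not part of the original module): with
# additional conditions enabled, the resolution (two values per sample) and the aspect
# ratio (one value per sample) are each embedded to size_emb_dim, so embedding_dim should
# be divisible by 3 for the concatenation to match the timestep embedding.
def _example_pixart_combined_embeddings():
    dim = 1152
    emb = PixArtAlphaCombinedTimestepSizeEmbeddings(dim, size_emb_dim=dim // 3, use_additional_conditions=True)
    timestep = torch.randint(0, 1000, (2,))
    resolution = torch.tensor([[512.0, 512.0], [768.0, 512.0]])  # (batch, 2): height, width
    aspect_ratio = torch.tensor([[1.0], [1.5]])                  # (batch, 1)
    conditioning = emb(timestep, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
    return conditioning  # (2, dim)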


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No modulation happening here; the caller applies the returned parameters.
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
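

# Illustrative usage sketch (assumed shapes, not part of the original module): the module
# returns the packed 6 * embedding_dim modulation parameters plus the raw timestep
# embedding; when use_additional_conditions is False the resolution/aspect_ratio entries
# of added_cond_kwargs can simply be None.
def _example_adaln_single():
    dim = 1152
    adaln = AdaLayerNormSingle(dim, use_additional_conditions=False)
    timestep = torch.randint(0, 1000, (2,))
    temb, embedded_timestep = adaln(
        timestep,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
        batch_size=2,
        hidden_dtype=torch.float32,
    )
    return temb, embedded_timestep  # (2, 6 * dim), (2, dim)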


class AdaLayerNormShift(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings (shift-only modulation).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        elementwise_affine (`bool`): Whether the layer norm has learnable affine parameters.
        eps (`float`): A small value added to the denominator for numerical stability.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
        x = self.norm(x) + shift.unsqueeze(dim=1)
        return x
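

# Illustrative usage sketch (assumed shapes, not part of the original module): unlike
# adaLN-Zero, this variant predicts only a per-channel shift from the timestep embedding
# and adds it after the fp32 layer norm.
def _example_adaln_shift():
    dim = 1152
    norm = AdaLayerNormShift(dim)
    x = torch.randn(2, 16, dim)   # (batch, seq_len, dim)
    emb = torch.randn(2, dim)     # (batch, dim) timestep embedding
    return norm(x, emb)           # (2, 16, dim)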


class EasyAnimateLayerNormZero(nn.Module):
    # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py
    # to add an fp32 layer norm option.
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
        norm_type: str = "fp32_layer_norm",
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
        elif norm_type == "fp32_layer_norm":
            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
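

# Illustrative usage sketch (assumed shapes, not part of the original module): a single
# linear layer on the conditioning embedding produces shift/scale/gate triplets for both
# the video hidden states and the text encoder hidden states; the gates are returned so
# the caller can apply them after the attention / feed-forward blocks.
def _example_layernorm_zero():
    cond_dim, dim = 512, 1152
    ln0 = EasyAnimateLayerNormZero(cond_dim, dim)
    hidden_states = torch.randn(2, 16, dim)          # (batch, video tokens, dim)
    encoder_hidden_states = torch.randn(2, 8, dim)   # (batch, text tokens, dim)
    temb = torch.randn(2, cond_dim)
    h, e, gate, enc_gate = ln0(hidden_states, encoder_hidden_states, temb)
    return h, e, gate, enc_gate  # (2, 16, dim), (2, 8, dim), (2, 1, dim), (2, 1, dim)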