# Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock from diffusers.models.attention_processor import ( Attention, AttentionProcessor, FluxAttnProcessor2_0, # FusedFluxAttnProcessor2_0, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.resnet import Downsample2D, Upsample2D from typing import List logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_3d_rotary_pos_embed( embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ RoPE for video tokens with 3D structure. Args: embed_dim: (`int`): The embedding dimension size, corresponding to hidden_size_head. crops_coords (`Tuple[int]`): The top-left and bottom-right coordinates of the crop. grid_size (`Tuple[int]`): The grid size of the spatial positional embedding (height, width). temporal_size (`int`): The size of the temporal dimension. theta (`float`): Scaling factor for frequency computation. use_real (`bool`): If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) # Compute dimensions for each axis dim_t = embed_dim // 4 dim_h = embed_dim // 8 * 3 dim_w = embed_dim // 8 * 3 # Temporal frequencies freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) grid_t = torch.from_numpy(grid_t).float() freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) freqs_t = freqs_t.repeat_interleave(2, dim=-1) # Spatial frequencies for height and width freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) grid_h = torch.from_numpy(grid_h).float() grid_w = torch.from_numpy(grid_w).float() freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) freqs_h = freqs_h.repeat_interleave(2, dim=-1) freqs_w = freqs_w.repeat_interleave(2, dim=-1) # Broadcast and concatenate tensors along specified dimension def broadcast(tensors, dim=-1): num_tensors = len(tensors) shape_lens = {len(t.shape) for t in tensors} assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" shape_len = list(shape_lens)[0] dim = (dim + shape_len) if dim < 0 else dim dims = list(zip(*(list(t.shape) for t in tensors))) expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] assert all( [*(len(set(t[1])) <= 2 for t in expandable_dims)] ), "invalid dimensions for broadcastable concatenation" max_dims = [(t[0], max(t[1])) for t in expandable_dims] expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] expanded_dims.insert(dim, (dim, dims[dim])) expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] return torch.cat(tensors, dim=dim) freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) t, h, w, d = freqs.shape freqs = freqs.view(t * h * w, d) # Generate sine and cosine components sin = freqs.sin() cos = freqs.cos() if use_real: return cos, sin else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. Args: embed_dim: (`int`): The embedding dimension size crops_coords (`Tuple[int]`) The top-left and bottom-right coordinates of the crop. grid_size (`Tuple[int]`): The grid size of the positional embedding. use_real (`bool`): If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. """ start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) # [2, W, H] grid = grid.reshape([2, 1, *grid.shape[1:]]) pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) return pos_embed def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): assert embed_dim % 4 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_rotary_pos_embed( embed_dim // 2, grid[0].reshape(-1), use_real=use_real ) # (H*W, D/2) if use_real else (H*W, D/4) emb_w = get_1d_rotary_pos_embed( embed_dim // 2, grid[1].reshape(-1), use_real=use_real ) # (H*W, D/2) if use_real else (H*W, D/4) if use_real: cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) return cos, sin else: emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) return emb def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): assert embed_dim % 4 == 0 emb_h = get_1d_rotary_pos_embed( embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor ) # (H, D/4) emb_w = get_1d_rotary_pos_embed( embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor ) # (W, D/4) emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) return emb def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (`int`): Dimension of the frequency tensor. pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar theta (`float`, *optional*, defaults to 10000.0): Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. linear_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the context extrapolation. Defaults to 1.0. ntk_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concateanted with themselves. freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): the dtype of the frequency tensor. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ assert dim % 2 == 0 if isinstance(pos, int): pos = np.arange(pos) theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] return freqs_cis class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() self.theta = theta self.axes_dim = axes_dim def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] pos = ids.squeeze().float().cpu().numpy() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype ) cos_out.append(cos) sin_out.append(sin) freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) return freqs_cos, freqs_sin class FusedFluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape # `sample` projections. qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` # `context` projections. if encoder_hidden_states is not None: encoder_qkv = attn.to_added_qkv(encoder_hidden_states) split_size = encoder_qkv.shape[-1] // 3 ( encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj, ) = torch.split(encoder_qkv, split_size, dim=-1) encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) # attention query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = ( hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :], ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states else: return hidden_states @maybe_allow_in_graph class SingleTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. """ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) self.norm = AdaLayerNormZeroSingle(dim) self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, processor=processor, qk_norm="rms_norm", eps=1e-6, pre_only=True, ) def forward( self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, ): residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, ) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) gate = gate.unsqueeze(1) hidden_states = gate * self.proj_out(hidden_states) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) return hidden_states @maybe_allow_in_graph class TransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. """ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): super().__init__() self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) if hasattr(F, "scaled_dot_product_attention"): processor = FluxAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." ) self.attn = Attention( query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, processor=processor, qk_norm=qk_norm, eps=eps, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, ): norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, ) # Process attention outputs for the `hidden_states`. attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output # Process attention outputs for the `encoder_hidden_states`. context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) return encoder_hidden_states, hidden_states class UVit2DConvEmbed(nn.Module): def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): super().__init__() self.embeddings = nn.Embedding(vocab_size, in_channels) self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) def forward(self, input_ids): embeddings = self.embeddings(input_ids) embeddings = self.layer_norm(embeddings) embeddings = embeddings.permute(0, 3, 1, 2) embeddings = self.conv(embeddings) return embeddings class ConvMlmLayer(nn.Module): def __init__( self, block_out_channels: int, in_channels: int, use_bias: bool, ln_elementwise_affine: bool, layer_norm_eps: float, codebook_size: int, ): super().__init__() self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) def forward(self, hidden_states): hidden_states = self.conv1(hidden_states) hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) logits = self.conv2(hidden_states) return logits class SwiGLU(nn.Module): r""" A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU` but uses SiLU / Swish instead of GeLU. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) self.activation = nn.SiLU() def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states, gate = hidden_states.chunk(2, dim=-1) return hidden_states * self.activation(gate) class ConvNextBlock(nn.Module): def __init__( self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 ): super().__init__() self.depthwise = nn.Conv2d( channels, channels, kernel_size=3, padding=1, groups=channels, bias=use_bias, ) self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) self.channelwise_act = nn.GELU() self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) self.channelwise_dropout = nn.Dropout(hidden_dropout) self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) def forward(self, x, cond_embeds): x_res = x x = self.depthwise(x) x = x.permute(0, 2, 3, 1) x = self.norm(x) x = self.channelwise_linear_1(x) x = self.channelwise_act(x) x = self.channelwise_norm(x) x = self.channelwise_linear_2(x) x = self.channelwise_dropout(x) x = x.permute(0, 3, 1, 2) x = x + x_res scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] return x class Simple_UVitBlock(nn.Module): def __init__( self, channels, ln_elementwise_affine, layer_norm_eps, use_bias, downsample: bool, upsample: bool, ): super().__init__() if downsample: self.downsample = Downsample2D( channels, use_conv=True, padding=0, name="Conv2d_0", kernel_size=2, norm_type="rms_norm", eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine, bias=use_bias, ) else: self.downsample = None if upsample: self.upsample = Upsample2D( channels, use_conv_transpose=True, kernel_size=2, padding=0, name="conv", norm_type="rms_norm", eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine, bias=use_bias, interpolate=False, ) else: self.upsample = None def forward(self, x): # print("before,", x.shape) if self.downsample is not None: # print('downsample') x = self.downsample(x) if self.upsample is not None: # print('upsample') x = self.upsample(x) # print("after,", x.shape) return x class UVitBlock(nn.Module): def __init__( self, channels, num_res_blocks: int, hidden_size, hidden_dropout, ln_elementwise_affine, layer_norm_eps, use_bias, block_num_heads, attention_dropout, downsample: bool, upsample: bool, ): super().__init__() if downsample: self.downsample = Downsample2D( channels, use_conv=True, padding=0, name="Conv2d_0", kernel_size=2, norm_type="rms_norm", eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine, bias=use_bias, ) else: self.downsample = None self.res_blocks = nn.ModuleList( [ ConvNextBlock( channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, ) for i in range(num_res_blocks) ] ) self.attention_blocks = nn.ModuleList( [ SkipFFTransformerBlock( channels, block_num_heads, channels // block_num_heads, hidden_size, use_bias, attention_dropout, channels, attention_bias=use_bias, attention_out_bias=use_bias, ) for _ in range(num_res_blocks) ] ) if upsample: self.upsample = Upsample2D( channels, use_conv_transpose=True, kernel_size=2, padding=0, name="conv", norm_type="rms_norm", eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine, bias=use_bias, interpolate=False, ) else: self.upsample = None def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): if self.downsample is not None: x = self.downsample(x) for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): x = res_block(x, pooled_text_emb) batch_size, channels, height, width = x.shape x = x.view(batch_size, channels, height * width).permute(0, 2, 1) x = attention_block( x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs ) x = x.permute(0, 2, 1).view(batch_size, channels, height, width) if self.upsample is not None: x = self.upsample(x) return x class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Flux. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Parameters: patch_size (`int`): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. """ _supports_gradient_checkpointing = False #True # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674. _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"] @register_to_config def __init__( self, patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, # unused in our implementation axes_dims_rope: Tuple[int] = (16, 56, 56), vocab_size: int = 8256, codebook_size: int = 8192, downsample: bool = False, upsample: bool = False, ): super().__init__() self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) self.time_text_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim ) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ TransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ SingleTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_single_layers) ] ) self.gradient_checkpointing = False in_channels_embed = self.inner_dim ln_elementwise_affine = True layer_norm_eps = 1e-06 use_bias = False micro_cond_embed_dim = 1280 self.embed = UVit2DConvEmbed( in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias ) self.mlm_layer = ConvMlmLayer( self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size ) self.cond_embed = TimestepEmbedding( micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias ) self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine) self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine) self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias) self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine) self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias) self.down_block = Simple_UVitBlock( self.inner_dim, ln_elementwise_affine, layer_norm_eps, use_bias, downsample, False, ) self.up_block = Simple_UVitBlock( self.inner_dim, #block_out_channels, ln_elementwise_affine, layer_norm_eps, use_bias, False, upsample=upsample, ) # self.fuse_qkv_projections() @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedFluxAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. """ if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples= None, controlnet_single_block_samples=None, return_dict: bool = True, micro_conds: torch.Tensor = None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ micro_cond_encode_dim = 256 # same as self.config.micro_cond_encode_dim = 256 from amused micro_cond_embeds = get_timestep_embedding( micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 ) micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1)) pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1) pooled_projections = pooled_projections.to(dtype=self.dtype) pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype) hidden_states = self.embed(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) hidden_states = self.down_block(hidden_states) batch_size, channels, height, width = hidden_states.shape hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) hidden_states = self.project_to_hidden_norm(hidden_states) hidden_states = self.project_to_hidden(hidden_states) if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) if txt_ids.ndim == 3: logger.warning( "Passing `txt_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) txt_ids = txt_ids[0] if img_ids.ndim == 3: logger.warning( "Passing `img_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) # controlnet residual if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) # controlnet residual if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( hidden_states[:, encoder_hidden_states.shape[1] :, ...] + controlnet_single_block_samples[index_block // interval_control] ) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] hidden_states = self.project_from_hidden_norm(hidden_states) hidden_states = self.project_from_hidden(hidden_states) hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) hidden_states = self.up_block(hidden_states) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) output = self.mlm_layer(hidden_states) # self.unfuse_qkv_projections() if not return_dict: return (output,) return output