import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.normalization import FP32LayerNorm, RMSNorm
from typing import Callable, List, Optional, Tuple, Union
import math
import numpy as np
from PIL import Image


class IPAFluxAttnProcessor2_0(nn.Module):
    """Flux IP-Adapter attention processor (PyTorch 2.0 scaled_dot_product_attention).

    Runs the joint text/latent attention of a Flux transformer block and adds a
    decoupled image cross-attention branch driven by `image_emb`, scaled by `scale`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size  # 3072
        self.cross_attention_dim = cross_attention_dim  # 4096
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)  # 128 = head_dim (3072 / 24 heads)
        # self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        image_emb: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # torch.Size([1, 24, 4800, 128])
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_emb is not None:
            # `ip-adapter` projections
            ip_hidden_states = image_emb
            ip_hidden_states_key_proj = self.to_k_ip(ip_hidden_states)
            ip_hidden_states_value_proj = self.to_v_ip(ip_hidden_states)

            ip_hidden_states_key_proj = ip_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            ip_hidden_states_value_proj = ip_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            ip_hidden_states_key_proj = self.norm_added_k(ip_hidden_states_key_proj)
            # ip_hidden_states_value_proj = self.norm_added_v(ip_hidden_states_value_proj)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_hidden_states_key_proj, ip_hidden_states_value_proj, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: text tokens are prepended to the latent tokens along the sequence dim
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)  # (512+3840, 128)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            if image_emb is not None:
                hidden_states = hidden_states + self.scale * ip_hidden_states

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            if image_emb is not None:
                hidden_states = hidden_states + self.scale * ip_hidden_states
            return hidden_states
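

# ---------------------------------------------------------------------------
# Minimal smoke test for the processor above. This is an illustrative sketch,
# not part of the original module: `_dummy_flux_attention` is a hypothetical
# stand-in exposing only the attributes the processor reads (to_q/to_k/to_v,
# add_*_proj, norm_*, to_out, to_add_out, heads). In a real pipeline `attn`
# is the diffusers `Attention` module of a Flux transformer block and
# `image_emb` comes from an image-projection model; the small dims here
# (dim=256, 4 heads, head_dim=64) and scale=0.7 are arbitrary toy values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    dim, heads, head_dim_toy = 256, 4, 64

    def _dummy_flux_attention():
        # Stand-in exposing just the attributes IPAFluxAttnProcessor2_0 uses.
        return SimpleNamespace(
            heads=heads,
            to_q=nn.Linear(dim, dim),
            to_k=nn.Linear(dim, dim),
            to_v=nn.Linear(dim, dim),
            norm_q=None,
            norm_k=None,
            add_q_proj=nn.Linear(dim, dim),
            add_k_proj=nn.Linear(dim, dim),
            add_v_proj=nn.Linear(dim, dim),
            norm_added_q=None,
            norm_added_k=None,
            to_out=nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)]),
            to_add_out=nn.Linear(dim, dim),
        )

    attn = _dummy_flux_attention()
    processor = IPAFluxAttnProcessor2_0(hidden_size=dim, cross_attention_dim=dim, scale=0.7)
    # The class hard-codes RMSNorm(128) because FLUX uses head_dim = 128
    # (3072 hidden / 24 heads); override it to match the toy head_dim of 64.
    processor.norm_added_k = RMSNorm(head_dim_toy, eps=1e-5, elementwise_affine=False)

    latents = torch.randn(1, 32, dim)    # latent/image tokens
    text_ctx = torch.randn(1, 8, dim)    # text tokens (double-stream block path)
    image_emb = torch.randn(1, 4, dim)   # IP-Adapter image tokens

    hidden, ctx = processor(attn, latents, image_emb, encoder_hidden_states=text_ctx)
    print(hidden.shape, ctx.shape)       # torch.Size([1, 32, 256]) torch.Size([1, 8, 256])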