multimodalart (HF staff) committed
Commit ae20f61
1 Parent(s): 757be97

Upload 3 files

Files changed (3):
  1. models/attention.py +1245 -0
  2. models/resampler.py +304 -0
  3. models/transformer_sd3.py +375 -0
models/attention.py ADDED
@@ -0,0 +1,1245 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
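+
+ # Usage note (illustrative sketch, not part of the upstream file): chunking the feed-forward
+ # along the sequence dimension trades peak memory for extra kernel launches and is
+ # numerically equivalent to a single pass. Assuming a hypothetical `ff = FeedForward(dim=64)`
+ # and an input `x = torch.randn(2, 128, 64)`:
+ #
+ #     y_full = ff(x)
+ #     y_chunked = _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=32)
+ #     assert torch.allclose(y_full, y_chunked, atol=1e-5)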
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection because we concatenate the visual features with the object features
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
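+
+ # Usage sketch (illustrative shapes only; GLIGEN-style gated fuser): object embeddings are
+ # projected to the visual width and gated into the visual tokens, so the output keeps the
+ # visual token shape.
+ #
+ #     fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
+ #     x = torch.randn(2, 64, 320)       # visual tokens
+ #     objs = torch.randn(2, 30, 768)    # grounding/object tokens
+ #     out = fuser(x, objs)              # -> (2, 64, 320)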
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ num_attention_heads: int,
107
+ attention_head_dim: int,
108
+ context_pre_only: bool = False,
109
+ qk_norm: Optional[str] = None,
110
+ use_dual_attention: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.use_dual_attention = use_dual_attention
115
+ self.context_pre_only = context_pre_only
116
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
+
118
+ if use_dual_attention:
119
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
120
+ else:
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ if context_norm_type == "ada_norm_continous":
124
+ self.norm1_context = AdaLayerNormContinuous(
125
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
+ )
127
+ elif context_norm_type == "ada_norm_zero":
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+ else:
130
+ raise ValueError(
131
+ f"Unknown context_norm_type: {context_norm_type}; currently only `ada_norm_continous` and `ada_norm_zero` are supported."
132
+ )
133
+
134
+ if hasattr(F, "scaled_dot_product_attention"):
135
+ processor = JointAttnProcessor2_0()
136
+ else:
137
+ raise ValueError(
138
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
+ )
140
+
141
+ self.attn = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=None,
144
+ added_kv_proj_dim=dim,
145
+ dim_head=attention_head_dim,
146
+ heads=num_attention_heads,
147
+ out_dim=dim,
148
+ context_pre_only=context_pre_only,
149
+ bias=True,
150
+ processor=processor,
151
+ qk_norm=qk_norm,
152
+ eps=1e-6,
153
+ )
154
+
155
+ if use_dual_attention:
156
+ self.attn2 = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ out_dim=dim,
162
+ bias=True,
163
+ processor=processor,
164
+ qk_norm=qk_norm,
165
+ eps=1e-6,
166
+ )
167
+ else:
168
+ self.attn2 = None
169
+
170
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
+
173
+ if not context_pre_only:
174
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
+ else:
177
+ self.norm2_context = None
178
+ self.ff_context = None
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
+ # Sets chunk feed-forward
187
+ self._chunk_size = chunk_size
188
+ self._chunk_dim = dim
189
+
190
+ def forward(
191
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
192
+ joint_attention_kwargs=None,
193
+ ):
194
+ if self.use_dual_attention:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
196
+ hidden_states, emb=temb
197
+ )
198
+ else:
199
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
200
+
201
+ if self.context_pre_only:
202
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
203
+ else:
204
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
205
+ encoder_hidden_states, emb=temb
206
+ )
207
+
208
+ # Attention.
209
+ attn_output, context_attn_output = self.attn(
210
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
211
+ **({} if joint_attention_kwargs is None else joint_attention_kwargs),
212
+ )
213
+
214
+ # Process attention outputs for the `hidden_states`.
215
+ attn_output = gate_msa.unsqueeze(1) * attn_output
216
+ hidden_states = hidden_states + attn_output
217
+
218
+ if self.use_dual_attention:
219
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
220
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
221
+ hidden_states = hidden_states + attn_output2
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
+ if self._chunk_size is not None:
226
+ # "feed_forward_chunk_size" can be used to save memory
227
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
228
+ else:
229
+ ff_output = self.ff(norm_hidden_states)
230
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
231
+
232
+ hidden_states = hidden_states + ff_output
233
+
234
+ # Process attention outputs for the `encoder_hidden_states`.
235
+ if self.context_pre_only:
236
+ encoder_hidden_states = None
237
+ else:
238
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
239
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
240
+
241
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
242
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
243
+ if self._chunk_size is not None:
244
+ # "feed_forward_chunk_size" can be used to save memory
245
+ context_ff_output = _chunked_feed_forward(
246
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
247
+ )
248
+ else:
249
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
250
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
251
+
252
+ return encoder_hidden_states, hidden_states
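+
+ # Usage sketch (shapes are illustrative assumptions, not tied to a specific SD3 checkpoint):
+ #
+ #     block = JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)
+ #     hidden = torch.randn(1, 4096, 1536)     # image/latent tokens
+ #     context = torch.randn(1, 333, 1536)     # text tokens
+ #     temb = torch.randn(1, 1536)             # pooled timestep/conditioning embedding
+ #     context, hidden = block(hidden, context, temb)
+ #
+ # When `context_pre_only=True` (the final block), the returned `context` is None.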
253
+
254
+
255
+ @maybe_allow_in_graph
256
+ class BasicTransformerBlock(nn.Module):
257
+ r"""
258
+ A basic Transformer block.
259
+
260
+ Parameters:
261
+ dim (`int`): The number of channels in the input and output.
262
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
263
+ attention_head_dim (`int`): The number of channels in each head.
264
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
265
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
266
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
267
+ num_embeds_ada_norm (:
268
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
269
+ attention_bias (:
270
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
271
+ only_cross_attention (`bool`, *optional*):
272
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
273
+ double_self_attention (`bool`, *optional*):
274
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
275
+ upcast_attention (`bool`, *optional*):
276
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
277
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
278
+ Whether to use learnable elementwise affine parameters for normalization.
279
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
280
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
281
+ final_dropout (`bool` *optional*, defaults to False):
282
+ Whether to apply a final dropout after the last feed-forward layer.
283
+ attention_type (`str`, *optional*, defaults to `"default"`):
284
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
285
+ positional_embeddings (`str`, *optional*, defaults to `None`):
286
+ The type of positional embeddings to apply to.
287
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
288
+ The maximum number of positional embeddings to apply.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ dim: int,
294
+ num_attention_heads: int,
295
+ attention_head_dim: int,
296
+ dropout=0.0,
297
+ cross_attention_dim: Optional[int] = None,
298
+ activation_fn: str = "geglu",
299
+ num_embeds_ada_norm: Optional[int] = None,
300
+ attention_bias: bool = False,
301
+ only_cross_attention: bool = False,
302
+ double_self_attention: bool = False,
303
+ upcast_attention: bool = False,
304
+ norm_elementwise_affine: bool = True,
305
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
306
+ norm_eps: float = 1e-5,
307
+ final_dropout: bool = False,
308
+ attention_type: str = "default",
309
+ positional_embeddings: Optional[str] = None,
310
+ num_positional_embeddings: Optional[int] = None,
311
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
312
+ ada_norm_bias: Optional[int] = None,
313
+ ff_inner_dim: Optional[int] = None,
314
+ ff_bias: bool = True,
315
+ attention_out_bias: bool = True,
316
+ ):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.num_attention_heads = num_attention_heads
320
+ self.attention_head_dim = attention_head_dim
321
+ self.dropout = dropout
322
+ self.cross_attention_dim = cross_attention_dim
323
+ self.activation_fn = activation_fn
324
+ self.attention_bias = attention_bias
325
+ self.double_self_attention = double_self_attention
326
+ self.norm_elementwise_affine = norm_elementwise_affine
327
+ self.positional_embeddings = positional_embeddings
328
+ self.num_positional_embeddings = num_positional_embeddings
329
+ self.only_cross_attention = only_cross_attention
330
+
331
+ # We keep these boolean flags for backward-compatibility.
332
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
333
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
334
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
335
+ self.use_layer_norm = norm_type == "layer_norm"
336
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
337
+
338
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
339
+ raise ValueError(
340
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
341
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
342
+ )
343
+
344
+ self.norm_type = norm_type
345
+ self.num_embeds_ada_norm = num_embeds_ada_norm
346
+
347
+ if positional_embeddings and (num_positional_embeddings is None):
348
+ raise ValueError(
349
+ "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
350
+ )
351
+
352
+ if positional_embeddings == "sinusoidal":
353
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
354
+ else:
355
+ self.pos_embed = None
356
+
357
+ # Define 3 blocks. Each block has its own normalization layer.
358
+ # 1. Self-Attn
359
+ if norm_type == "ada_norm":
360
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
361
+ elif norm_type == "ada_norm_zero":
362
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
363
+ elif norm_type == "ada_norm_continuous":
364
+ self.norm1 = AdaLayerNormContinuous(
365
+ dim,
366
+ ada_norm_continous_conditioning_embedding_dim,
367
+ norm_elementwise_affine,
368
+ norm_eps,
369
+ ada_norm_bias,
370
+ "rms_norm",
371
+ )
372
+ else:
373
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
374
+
375
+ self.attn1 = Attention(
376
+ query_dim=dim,
377
+ heads=num_attention_heads,
378
+ dim_head=attention_head_dim,
379
+ dropout=dropout,
380
+ bias=attention_bias,
381
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
382
+ upcast_attention=upcast_attention,
383
+ out_bias=attention_out_bias,
384
+ )
385
+
386
+ # 2. Cross-Attn
387
+ if cross_attention_dim is not None or double_self_attention:
388
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
389
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
390
+ # the second cross attention block.
391
+ if norm_type == "ada_norm":
392
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
393
+ elif norm_type == "ada_norm_continuous":
394
+ self.norm2 = AdaLayerNormContinuous(
395
+ dim,
396
+ ada_norm_continous_conditioning_embedding_dim,
397
+ norm_elementwise_affine,
398
+ norm_eps,
399
+ ada_norm_bias,
400
+ "rms_norm",
401
+ )
402
+ else:
403
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
404
+
405
+ self.attn2 = Attention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ out_bias=attention_out_bias,
414
+ ) # is self-attn if encoder_hidden_states is none
415
+ else:
416
+ if norm_type == "ada_norm_single": # For Latte
417
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
418
+ else:
419
+ self.norm2 = None
420
+ self.attn2 = None
421
+
422
+ # 3. Feed-forward
423
+ if norm_type == "ada_norm_continuous":
424
+ self.norm3 = AdaLayerNormContinuous(
425
+ dim,
426
+ ada_norm_continous_conditioning_embedding_dim,
427
+ norm_elementwise_affine,
428
+ norm_eps,
429
+ ada_norm_bias,
430
+ "layer_norm",
431
+ )
432
+
433
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
434
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
435
+ elif norm_type == "layer_norm_i2vgen":
436
+ self.norm3 = None
437
+
438
+ self.ff = FeedForward(
439
+ dim,
440
+ dropout=dropout,
441
+ activation_fn=activation_fn,
442
+ final_dropout=final_dropout,
443
+ inner_dim=ff_inner_dim,
444
+ bias=ff_bias,
445
+ )
446
+
447
+ # 4. Fuser
448
+ if attention_type == "gated" or attention_type == "gated-text-image":
449
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
450
+
451
+ # 5. Scale-shift for PixArt-Alpha.
452
+ if norm_type == "ada_norm_single":
453
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
454
+
455
+ # let chunk size default to None
456
+ self._chunk_size = None
457
+ self._chunk_dim = 0
458
+
459
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
460
+ # Sets chunk feed-forward
461
+ self._chunk_size = chunk_size
462
+ self._chunk_dim = dim
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ encoder_hidden_states: Optional[torch.Tensor] = None,
469
+ encoder_attention_mask: Optional[torch.Tensor] = None,
470
+ timestep: Optional[torch.LongTensor] = None,
471
+ cross_attention_kwargs: Dict[str, Any] = None,
472
+ class_labels: Optional[torch.LongTensor] = None,
473
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
474
+ ) -> torch.Tensor:
475
+ if cross_attention_kwargs is not None:
476
+ if cross_attention_kwargs.get("scale", None) is not None:
477
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
478
+
479
+ # Notice that normalization is always applied before the real computation in the following blocks.
480
+ # 0. Self-Attention
481
+ batch_size = hidden_states.shape[0]
482
+
483
+ if self.norm_type == "ada_norm":
484
+ norm_hidden_states = self.norm1(hidden_states, timestep)
485
+ elif self.norm_type == "ada_norm_zero":
486
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
487
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
488
+ )
489
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
490
+ norm_hidden_states = self.norm1(hidden_states)
491
+ elif self.norm_type == "ada_norm_continuous":
492
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
493
+ elif self.norm_type == "ada_norm_single":
494
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
495
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
496
+ ).chunk(6, dim=1)
497
+ norm_hidden_states = self.norm1(hidden_states)
498
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
499
+ else:
500
+ raise ValueError("Incorrect norm used")
501
+
502
+ if self.pos_embed is not None:
503
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
504
+
505
+ # 1. Prepare GLIGEN inputs
506
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
507
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
508
+
509
+ attn_output = self.attn1(
510
+ norm_hidden_states,
511
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
512
+ attention_mask=attention_mask,
513
+ **cross_attention_kwargs,
514
+ )
515
+
516
+ if self.norm_type == "ada_norm_zero":
517
+ attn_output = gate_msa.unsqueeze(1) * attn_output
518
+ elif self.norm_type == "ada_norm_single":
519
+ attn_output = gate_msa * attn_output
520
+
521
+ hidden_states = attn_output + hidden_states
522
+ if hidden_states.ndim == 4:
523
+ hidden_states = hidden_states.squeeze(1)
524
+
525
+ # 1.2 GLIGEN Control
526
+ if gligen_kwargs is not None:
527
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
528
+
529
+ # 3. Cross-Attention
530
+ if self.attn2 is not None:
531
+ if self.norm_type == "ada_norm":
532
+ norm_hidden_states = self.norm2(hidden_states, timestep)
533
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
534
+ norm_hidden_states = self.norm2(hidden_states)
535
+ elif self.norm_type == "ada_norm_single":
536
+ # For PixArt norm2 isn't applied here:
537
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
538
+ norm_hidden_states = hidden_states
539
+ elif self.norm_type == "ada_norm_continuous":
540
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
541
+ else:
542
+ raise ValueError("Incorrect norm")
543
+
544
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
545
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
546
+
547
+ attn_output = self.attn2(
548
+ norm_hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=encoder_attention_mask,
551
+ **cross_attention_kwargs,
552
+ )
553
+ hidden_states = attn_output + hidden_states
554
+
555
+ # 4. Feed-forward
556
+ # i2vgen doesn't have this norm 🤷‍♂️
557
+ if self.norm_type == "ada_norm_continuous":
558
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
559
+ elif not self.norm_type == "ada_norm_single":
560
+ norm_hidden_states = self.norm3(hidden_states)
561
+
562
+ if self.norm_type == "ada_norm_zero":
563
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
564
+
565
+ if self.norm_type == "ada_norm_single":
566
+ norm_hidden_states = self.norm2(hidden_states)
567
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
568
+
569
+ if self._chunk_size is not None:
570
+ # "feed_forward_chunk_size" can be used to save memory
571
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
572
+ else:
573
+ ff_output = self.ff(norm_hidden_states)
574
+
575
+ if self.norm_type == "ada_norm_zero":
576
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
577
+ elif self.norm_type == "ada_norm_single":
578
+ ff_output = gate_mlp * ff_output
579
+
580
+ hidden_states = ff_output + hidden_states
581
+ if hidden_states.ndim == 4:
582
+ hidden_states = hidden_states.squeeze(1)
583
+
584
+ return hidden_states
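+
+ # Usage sketch for the default `layer_norm` path (dimensions are illustrative assumptions):
+ #
+ #     block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40,
+ #                                   cross_attention_dim=768)
+ #     x = torch.randn(2, 64, 320)             # latent tokens
+ #     ctx = torch.randn(2, 77, 768)           # text-encoder hidden states
+ #     out = block(x, encoder_hidden_states=ctx)   # -> (2, 64, 320)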
585
+
586
+
587
+ class LuminaFeedForward(nn.Module):
588
+ r"""
589
+ A feed-forward layer.
590
+
591
+ Parameters:
592
+ hidden_size (`int`):
593
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
594
+ hidden representations.
595
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
596
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
597
+ of this value.
598
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
599
+ dimension. Defaults to None.
600
+ """
601
+
602
+ def __init__(
603
+ self,
604
+ dim: int,
605
+ inner_dim: int,
606
+ multiple_of: Optional[int] = 256,
607
+ ffn_dim_multiplier: Optional[float] = None,
608
+ ):
609
+ super().__init__()
610
+ inner_dim = int(2 * inner_dim / 3)
611
+ # custom hidden_size factor multiplier
612
+ if ffn_dim_multiplier is not None:
613
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
614
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
615
+
616
+ self.linear_1 = nn.Linear(
617
+ dim,
618
+ inner_dim,
619
+ bias=False,
620
+ )
621
+ self.linear_2 = nn.Linear(
622
+ inner_dim,
623
+ dim,
624
+ bias=False,
625
+ )
626
+ self.linear_3 = nn.Linear(
627
+ dim,
628
+ inner_dim,
629
+ bias=False,
630
+ )
631
+ self.silu = FP32SiLU()
632
+
633
+ def forward(self, x):
634
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
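+
+ # Note (descriptive comment added for clarity): this is a SwiGLU-style gated MLP,
+ # i.e. W2(SiLU(W1 x) * W3 x), so the effective inner width is the 2/3-scaled,
+ # `multiple_of`-rounded value computed in `__init__` rather than `inner_dim` as passed.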
635
+
636
+
637
+ @maybe_allow_in_graph
638
+ class TemporalBasicTransformerBlock(nn.Module):
639
+ r"""
640
+ A basic Transformer block for video like data.
641
+
642
+ Parameters:
643
+ dim (`int`): The number of channels in the input and output.
644
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
645
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
646
+ attention_head_dim (`int`): The number of channels in each head.
647
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
648
+ """
649
+
650
+ def __init__(
651
+ self,
652
+ dim: int,
653
+ time_mix_inner_dim: int,
654
+ num_attention_heads: int,
655
+ attention_head_dim: int,
656
+ cross_attention_dim: Optional[int] = None,
657
+ ):
658
+ super().__init__()
659
+ self.is_res = dim == time_mix_inner_dim
660
+
661
+ self.norm_in = nn.LayerNorm(dim)
662
+
663
+ # Define 3 blocks. Each block has its own normalization layer.
664
+ # 1. Self-Attn
665
+ self.ff_in = FeedForward(
666
+ dim,
667
+ dim_out=time_mix_inner_dim,
668
+ activation_fn="geglu",
669
+ )
670
+
671
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
672
+ self.attn1 = Attention(
673
+ query_dim=time_mix_inner_dim,
674
+ heads=num_attention_heads,
675
+ dim_head=attention_head_dim,
676
+ cross_attention_dim=None,
677
+ )
678
+
679
+ # 2. Cross-Attn
680
+ if cross_attention_dim is not None:
681
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
682
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
683
+ # the second cross attention block.
684
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
685
+ self.attn2 = Attention(
686
+ query_dim=time_mix_inner_dim,
687
+ cross_attention_dim=cross_attention_dim,
688
+ heads=num_attention_heads,
689
+ dim_head=attention_head_dim,
690
+ ) # is self-attn if encoder_hidden_states is none
691
+ else:
692
+ self.norm2 = None
693
+ self.attn2 = None
694
+
695
+ # 3. Feed-forward
696
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
697
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
698
+
699
+ # let chunk size default to None
700
+ self._chunk_size = None
701
+ self._chunk_dim = None
702
+
703
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
704
+ # Sets chunk feed-forward
705
+ self._chunk_size = chunk_size
706
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
707
+ self._chunk_dim = 1
708
+
709
+ def forward(
710
+ self,
711
+ hidden_states: torch.Tensor,
712
+ num_frames: int,
713
+ encoder_hidden_states: Optional[torch.Tensor] = None,
714
+ ) -> torch.Tensor:
715
+ # Notice that normalization is always applied before the real computation in the following blocks.
716
+ # 0. Self-Attention
717
+ batch_size = hidden_states.shape[0]
718
+
719
+ batch_frames, seq_length, channels = hidden_states.shape
720
+ batch_size = batch_frames // num_frames
721
+
722
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
723
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
724
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
725
+
726
+ residual = hidden_states
727
+ hidden_states = self.norm_in(hidden_states)
728
+
729
+ if self._chunk_size is not None:
730
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
731
+ else:
732
+ hidden_states = self.ff_in(hidden_states)
733
+
734
+ if self.is_res:
735
+ hidden_states = hidden_states + residual
736
+
737
+ norm_hidden_states = self.norm1(hidden_states)
738
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
739
+ hidden_states = attn_output + hidden_states
740
+
741
+ # 3. Cross-Attention
742
+ if self.attn2 is not None:
743
+ norm_hidden_states = self.norm2(hidden_states)
744
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
745
+ hidden_states = attn_output + hidden_states
746
+
747
+ # 4. Feed-forward
748
+ norm_hidden_states = self.norm3(hidden_states)
749
+
750
+ if self._chunk_size is not None:
751
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
752
+ else:
753
+ ff_output = self.ff(norm_hidden_states)
754
+
755
+ if self.is_res:
756
+ hidden_states = ff_output + hidden_states
757
+ else:
758
+ hidden_states = ff_output
759
+
760
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
761
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
762
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
763
+
764
+ return hidden_states
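+
+ # Usage sketch (illustrative shapes): the block expects frames folded into the batch
+ # dimension, i.e. `hidden_states` of shape (batch * num_frames, seq_len, dim), and
+ # returns the same layout.
+ #
+ #     block = TemporalBasicTransformerBlock(dim=320, time_mix_inner_dim=320,
+ #                                           num_attention_heads=8, attention_head_dim=40)
+ #     x = torch.randn(2 * 16, 64, 320)        # 2 videos, 16 frames each
+ #     out = block(x, num_frames=16)           # -> (32, 64, 320)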
765
+
766
+
767
+ class SkipFFTransformerBlock(nn.Module):
768
+ def __init__(
769
+ self,
770
+ dim: int,
771
+ num_attention_heads: int,
772
+ attention_head_dim: int,
773
+ kv_input_dim: int,
774
+ kv_input_dim_proj_use_bias: bool,
775
+ dropout=0.0,
776
+ cross_attention_dim: Optional[int] = None,
777
+ attention_bias: bool = False,
778
+ attention_out_bias: bool = True,
779
+ ):
780
+ super().__init__()
781
+ if kv_input_dim != dim:
782
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
783
+ else:
784
+ self.kv_mapper = None
785
+
786
+ self.norm1 = RMSNorm(dim, 1e-06)
787
+
788
+ self.attn1 = Attention(
789
+ query_dim=dim,
790
+ heads=num_attention_heads,
791
+ dim_head=attention_head_dim,
792
+ dropout=dropout,
793
+ bias=attention_bias,
794
+ cross_attention_dim=cross_attention_dim,
795
+ out_bias=attention_out_bias,
796
+ )
797
+
798
+ self.norm2 = RMSNorm(dim, 1e-06)
799
+
800
+ self.attn2 = Attention(
801
+ query_dim=dim,
802
+ cross_attention_dim=cross_attention_dim,
803
+ heads=num_attention_heads,
804
+ dim_head=attention_head_dim,
805
+ dropout=dropout,
806
+ bias=attention_bias,
807
+ out_bias=attention_out_bias,
808
+ )
809
+
810
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
811
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
812
+
813
+ if self.kv_mapper is not None:
814
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
815
+
816
+ norm_hidden_states = self.norm1(hidden_states)
817
+
818
+ attn_output = self.attn1(
819
+ norm_hidden_states,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ **cross_attention_kwargs,
822
+ )
823
+
824
+ hidden_states = attn_output + hidden_states
825
+
826
+ norm_hidden_states = self.norm2(hidden_states)
827
+
828
+ attn_output = self.attn2(
829
+ norm_hidden_states,
830
+ encoder_hidden_states=encoder_hidden_states,
831
+ **cross_attention_kwargs,
832
+ )
833
+
834
+ hidden_states = attn_output + hidden_states
835
+
836
+ return hidden_states
837
+
838
+
839
+ @maybe_allow_in_graph
840
+ class FreeNoiseTransformerBlock(nn.Module):
841
+ r"""
842
+ A FreeNoise Transformer block.
843
+
844
+ Parameters:
845
+ dim (`int`):
846
+ The number of channels in the input and output.
847
+ num_attention_heads (`int`):
848
+ The number of heads to use for multi-head attention.
849
+ attention_head_dim (`int`):
850
+ The number of channels in each head.
851
+ dropout (`float`, *optional*, defaults to 0.0):
852
+ The dropout probability to use.
853
+ cross_attention_dim (`int`, *optional*):
854
+ The size of the encoder_hidden_states vector for cross attention.
855
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
856
+ Activation function to be used in feed-forward.
857
+ num_embeds_ada_norm (`int`, *optional*):
858
+ The number of diffusion steps used during training. See `Transformer2DModel`.
859
+ attention_bias (`bool`, defaults to `False`):
860
+ Configure if the attentions should contain a bias parameter.
861
+ only_cross_attention (`bool`, defaults to `False`):
862
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
863
+ double_self_attention (`bool`, defaults to `False`):
864
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
865
+ upcast_attention (`bool`, defaults to `False`):
866
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
867
+ norm_elementwise_affine (`bool`, defaults to `True`):
868
+ Whether to use learnable elementwise affine parameters for normalization.
869
+ norm_type (`str`, defaults to `"layer_norm"`):
870
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
871
+ final_dropout (`bool` defaults to `False`):
872
+ Whether to apply a final dropout after the last feed-forward layer.
873
+ attention_type (`str`, defaults to `"default"`):
874
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
875
+ positional_embeddings (`str`, *optional*):
876
+ The type of positional embeddings to apply to.
877
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
878
+ The maximum number of positional embeddings to apply.
879
+ ff_inner_dim (`int`, *optional*):
880
+ Hidden dimension of feed-forward MLP.
881
+ ff_bias (`bool`, defaults to `True`):
882
+ Whether or not to use bias in feed-forward MLP.
883
+ attention_out_bias (`bool`, defaults to `True`):
884
+ Whether or not to use bias in attention output project layer.
885
+ context_length (`int`, defaults to `16`):
886
+ The maximum number of frames that the FreeNoise block processes at once.
887
+ context_stride (`int`, defaults to `4`):
888
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
889
+ weighting_scheme (`str`, defaults to `"pyramid"`):
890
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
891
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
892
+ used.
893
+ """
894
+
895
+ def __init__(
896
+ self,
897
+ dim: int,
898
+ num_attention_heads: int,
899
+ attention_head_dim: int,
900
+ dropout: float = 0.0,
901
+ cross_attention_dim: Optional[int] = None,
902
+ activation_fn: str = "geglu",
903
+ num_embeds_ada_norm: Optional[int] = None,
904
+ attention_bias: bool = False,
905
+ only_cross_attention: bool = False,
906
+ double_self_attention: bool = False,
907
+ upcast_attention: bool = False,
908
+ norm_elementwise_affine: bool = True,
909
+ norm_type: str = "layer_norm",
910
+ norm_eps: float = 1e-5,
911
+ final_dropout: bool = False,
912
+ positional_embeddings: Optional[str] = None,
913
+ num_positional_embeddings: Optional[int] = None,
914
+ ff_inner_dim: Optional[int] = None,
915
+ ff_bias: bool = True,
916
+ attention_out_bias: bool = True,
917
+ context_length: int = 16,
918
+ context_stride: int = 4,
919
+ weighting_scheme: str = "pyramid",
920
+ ):
921
+ super().__init__()
922
+ self.dim = dim
923
+ self.num_attention_heads = num_attention_heads
924
+ self.attention_head_dim = attention_head_dim
925
+ self.dropout = dropout
926
+ self.cross_attention_dim = cross_attention_dim
927
+ self.activation_fn = activation_fn
928
+ self.attention_bias = attention_bias
929
+ self.double_self_attention = double_self_attention
930
+ self.norm_elementwise_affine = norm_elementwise_affine
931
+ self.positional_embeddings = positional_embeddings
932
+ self.num_positional_embeddings = num_positional_embeddings
933
+ self.only_cross_attention = only_cross_attention
934
+
935
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
936
+
937
+ # We keep these boolean flags for backward-compatibility.
938
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
939
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
940
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
941
+ self.use_layer_norm = norm_type == "layer_norm"
942
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
943
+
944
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
945
+ raise ValueError(
946
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
947
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
948
+ )
949
+
950
+ self.norm_type = norm_type
951
+ self.num_embeds_ada_norm = num_embeds_ada_norm
952
+
953
+ if positional_embeddings and (num_positional_embeddings is None):
954
+ raise ValueError(
955
+ "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
956
+ )
957
+
958
+ if positional_embeddings == "sinusoidal":
959
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
960
+ else:
961
+ self.pos_embed = None
962
+
963
+ # Define 3 blocks. Each block has its own normalization layer.
964
+ # 1. Self-Attn
965
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
966
+
967
+ self.attn1 = Attention(
968
+ query_dim=dim,
969
+ heads=num_attention_heads,
970
+ dim_head=attention_head_dim,
971
+ dropout=dropout,
972
+ bias=attention_bias,
973
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
974
+ upcast_attention=upcast_attention,
975
+ out_bias=attention_out_bias,
976
+ )
977
+
978
+ # 2. Cross-Attn
979
+ if cross_attention_dim is not None or double_self_attention:
980
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
981
+
982
+ self.attn2 = Attention(
983
+ query_dim=dim,
984
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
985
+ heads=num_attention_heads,
986
+ dim_head=attention_head_dim,
987
+ dropout=dropout,
988
+ bias=attention_bias,
989
+ upcast_attention=upcast_attention,
990
+ out_bias=attention_out_bias,
991
+ ) # is self-attn if encoder_hidden_states is none
992
+
993
+ # 3. Feed-forward
994
+ self.ff = FeedForward(
995
+ dim,
996
+ dropout=dropout,
997
+ activation_fn=activation_fn,
998
+ final_dropout=final_dropout,
999
+ inner_dim=ff_inner_dim,
1000
+ bias=ff_bias,
1001
+ )
1002
+
1003
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1004
+
1005
+ # let chunk size default to None
1006
+ self._chunk_size = None
1007
+ self._chunk_dim = 0
1008
+
1009
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1010
+ frame_indices = []
1011
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1012
+ window_start = i
1013
+ window_end = min(num_frames, i + self.context_length)
1014
+ frame_indices.append((window_start, window_end))
1015
+ return frame_indices
1016
+
1017
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1018
+ if weighting_scheme == "flat":
1019
+ weights = [1.0] * num_frames
1020
+
1021
+ elif weighting_scheme == "pyramid":
1022
+ if num_frames % 2 == 0:
1023
+ # num_frames = 4 => [1, 2, 2, 1]
1024
+ mid = num_frames // 2
1025
+ weights = list(range(1, mid + 1))
1026
+ weights = weights + weights[::-1]
1027
+ else:
1028
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1029
+ mid = (num_frames + 1) // 2
1030
+ weights = list(range(1, mid))
1031
+ weights = weights + [mid] + weights[::-1]
1032
+
1033
+ elif weighting_scheme == "delayed_reverse_sawtooth":
1034
+ if num_frames % 2 == 0:
1035
+ # num_frames = 4 => [0.01, 2, 2, 1]
1036
+ mid = num_frames // 2
1037
+ weights = [0.01] * (mid - 1) + [mid]
1038
+ weights = weights + list(range(mid, 0, -1))
1039
+ else:
1040
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1041
+ mid = (num_frames + 1) // 2
1042
+ weights = [0.01] * mid
1043
+ weights = weights + list(range(mid, 0, -1))
1044
+ else:
1045
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1046
+
1047
+ return weights
1048
+
1049
+ def set_free_noise_properties(
1050
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1051
+ ) -> None:
1052
+ self.context_length = context_length
1053
+ self.context_stride = context_stride
1054
+ self.weighting_scheme = weighting_scheme
1055
+
1056
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1057
+ # Sets chunk feed-forward
1058
+ self._chunk_size = chunk_size
1059
+ self._chunk_dim = dim
1060
+
1061
+ def forward(
1062
+ self,
1063
+ hidden_states: torch.Tensor,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1066
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1067
+ cross_attention_kwargs: Dict[str, Any] = None,
1068
+ *args,
1069
+ **kwargs,
1070
+ ) -> torch.Tensor:
1071
+ if cross_attention_kwargs is not None:
1072
+ if cross_attention_kwargs.get("scale", None) is not None:
1073
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1074
+
1075
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1076
+
1077
+ # hidden_states: [B x H x W, F, C]
1078
+ device = hidden_states.device
1079
+ dtype = hidden_states.dtype
1080
+
1081
+ num_frames = hidden_states.size(1)
1082
+ frame_indices = self._get_frame_indices(num_frames)
1083
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1084
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1085
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1086
+
1087
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1088
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1089
+ # [(0, 16), (4, 20), (8, 24), (10, 26)]
1090
+ if not is_last_frame_batch_complete:
1091
+ if num_frames < self.context_length:
1092
+ raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
1093
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1094
+ frame_indices.append((num_frames - self.context_length, num_frames))
1095
+
1096
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1097
+ accumulated_values = torch.zeros_like(hidden_states)
1098
+
1099
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1100
+ # Slicing is needed to handle cases like frame_indices=[(0, 16), (16, 20)] when the
1101
+ # user provides a video with 19 frames, i.e. a frame count that is not a multiple of
1102
+ # `context_length`.
1103
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1104
+ weights *= frame_weights
1105
+
1106
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1107
+
1108
+ # Notice that normalization is always applied before the real computation in the following blocks.
1109
+ # 1. Self-Attention
1110
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1111
+
1112
+ if self.pos_embed is not None:
1113
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1114
+
1115
+ attn_output = self.attn1(
1116
+ norm_hidden_states,
1117
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1118
+ attention_mask=attention_mask,
1119
+ **cross_attention_kwargs,
1120
+ )
1121
+
1122
+ hidden_states_chunk = attn_output + hidden_states_chunk
1123
+ if hidden_states_chunk.ndim == 4:
1124
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1125
+
1126
+ # 2. Cross-Attention
1127
+ if self.attn2 is not None:
1128
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1129
+
1130
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1131
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
+
1133
+ attn_output = self.attn2(
1134
+ norm_hidden_states,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ attention_mask=encoder_attention_mask,
1137
+ **cross_attention_kwargs,
1138
+ )
1139
+ hidden_states_chunk = attn_output + hidden_states_chunk
1140
+
1141
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1142
+ accumulated_values[:, -last_frame_batch_length:] += (
1143
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1144
+ )
1145
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1146
+ else:
1147
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1148
+ num_times_accumulated[:, frame_start:frame_end] += weights
1149
+
1150
+ # TODO(aryan): Maybe this could be done in a better way.
1151
+ #
1152
+ # Previously, this was:
1153
+ # hidden_states = torch.where(
1154
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1155
+ # )
1156
+ #
1157
+ # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1158
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1159
+ # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1160
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1161
+ hidden_states = torch.cat(
1162
+ [
1163
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1164
+ for accumulated_split, num_times_split in zip(
1165
+ accumulated_values.split(self.context_length, dim=1),
1166
+ num_times_accumulated.split(self.context_length, dim=1),
1167
+ )
1168
+ ],
1169
+ dim=1,
1170
+ ).to(dtype)
1171
+
1172
+ # 3. Feed-forward
1173
+ norm_hidden_states = self.norm3(hidden_states)
1174
+
1175
+ if self._chunk_size is not None:
1176
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1177
+ else:
1178
+ ff_output = self.ff(norm_hidden_states)
1179
+
1180
+ hidden_states = ff_output + hidden_states
1181
+ if hidden_states.ndim == 4:
1182
+ hidden_states = hidden_states.squeeze(1)
1183
+
1184
+ return hidden_states
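+
+ # Usage sketch (illustrative shapes): FreeNoise operates on tensors of shape
+ # (batch * height * width, num_frames, channels) and blends overlapping temporal windows:
+ #
+ #     block = FreeNoiseTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40,
+ #                                       cross_attention_dim=768, context_length=16, context_stride=4)
+ #     x = torch.randn(4, 24, 320)             # 24 frames, spatial positions folded into batch
+ #     ctx = torch.randn(4, 77, 768)
+ #     out = block(x, encoder_hidden_states=ctx)   # -> (4, 24, 320)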
1185
+
1186
+
1187
+ class FeedForward(nn.Module):
1188
+ r"""
1189
+ A feed-forward layer.
1190
+
1191
+ Parameters:
1192
+ dim (`int`): The number of channels in the input.
1193
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1194
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1195
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1196
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1197
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1198
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1199
+ """
1200
+
1201
+ def __init__(
1202
+ self,
1203
+ dim: int,
1204
+ dim_out: Optional[int] = None,
1205
+ mult: int = 4,
1206
+ dropout: float = 0.0,
1207
+ activation_fn: str = "geglu",
1208
+ final_dropout: bool = False,
1209
+ inner_dim=None,
1210
+ bias: bool = True,
1211
+ ):
1212
+ super().__init__()
1213
+ if inner_dim is None:
1214
+ inner_dim = int(dim * mult)
1215
+ dim_out = dim_out if dim_out is not None else dim
1216
+
1217
+ if activation_fn == "gelu":
1218
+ act_fn = GELU(dim, inner_dim, bias=bias)
1219
+ if activation_fn == "gelu-approximate":
1220
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1221
+ elif activation_fn == "geglu":
1222
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1223
+ elif activation_fn == "geglu-approximate":
1224
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1225
+ elif activation_fn == "swiglu":
1226
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1227
+
1228
+ self.net = nn.ModuleList([])
1229
+ # project in
1230
+ self.net.append(act_fn)
1231
+ # project dropout
1232
+ self.net.append(nn.Dropout(dropout))
1233
+ # project out
1234
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1235
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1236
+ if final_dropout:
1237
+ self.net.append(nn.Dropout(dropout))
1238
+
1239
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1240
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1241
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1242
+ deprecate("scale", "1.0.0", deprecation_message)
1243
+ for module in self.net:
1244
+ hidden_states = module(hidden_states)
1245
+ return hidden_states
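+
+ # Usage sketch (illustrative values): with the default "geglu" activation the MLP expands
+ # `dim` by `mult` and projects back, preserving the input shape:
+ #
+ #     ff = FeedForward(dim=320, mult=4, activation_fn="geglu")
+ #     assert ff(torch.randn(2, 64, 320)).shape == (2, 64, 320)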
models/resampler.py ADDED
@@ -0,0 +1,304 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
8
+
9
+ def get_timestep_embedding(
10
+ timesteps: torch.Tensor,
11
+ embedding_dim: int,
12
+ flip_sin_to_cos: bool = False,
13
+ downscale_freq_shift: float = 1,
14
+ scale: float = 1,
15
+ max_period: int = 10000,
16
+ ):
17
+ """
18
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
+
20
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
21
+ These may be fractional.
22
+ :param embedding_dim: the dimension of the output.
23
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
24
+ """
25
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
26
+
27
+ half_dim = embedding_dim // 2
28
+ exponent = -math.log(max_period) * torch.arange(
29
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
30
+ )
31
+ exponent = exponent / (half_dim - downscale_freq_shift)
32
+
33
+ emb = torch.exp(exponent)
34
+ emb = timesteps[:, None].float() * emb[None, :]
35
+
36
+ # scale embeddings
37
+ emb = scale * emb
38
+
39
+ # concat sine and cosine embeddings
40
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
41
+
42
+ # flip sine and cosine embeddings
43
+ if flip_sin_to_cos:
44
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
45
+
46
+ # zero pad
47
+ if embedding_dim % 2 == 1:
48
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
49
+ return emb
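+
+ # Usage sketch (illustrative values): for timesteps = torch.tensor([0., 10., 999.]) and
+ # embedding_dim=320 this returns a (3, 320) tensor laid out as [sin | cos] features,
+ # matching the DDPM sinusoidal embedding.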
50
+
51
+
52
+ # FFN
53
+ def FeedForward(dim, mult=4):
54
+ inner_dim = int(dim * mult)
55
+ return nn.Sequential(
56
+ nn.LayerNorm(dim),
57
+ nn.Linear(dim, inner_dim, bias=False),
58
+ nn.GELU(),
59
+ nn.Linear(inner_dim, dim, bias=False),
60
+ )
61
+
62
+
63
+ def reshape_tensor(x, heads):
64
+ bs, length, width = x.shape
65
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
66
+ x = x.view(bs, length, heads, -1)
67
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
68
+ x = x.transpose(1, 2)
69
+ # (bs, n_heads, length, dim_per_head) --> same shape; the reshape after the transpose yields a contiguous tensor
70
+ x = x.reshape(bs, heads, length, -1)
71
+ return x
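+
+ # Shape example (illustrative): for x of shape (2, 257, 1024) and heads=16 the result has
+ # shape (2, 16, 257, 64), i.e. the head dimension is split out and moved ahead of the
+ # sequence axis.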
72
+
73
+
74
+ class PerceiverAttention(nn.Module):
75
+ def __init__(self, *, dim, dim_head=64, heads=8):
76
+ super().__init__()
77
+ self.scale = dim_head**-0.5
78
+ self.dim_head = dim_head
79
+ self.heads = heads
80
+ inner_dim = dim_head * heads
81
+
82
+ self.norm1 = nn.LayerNorm(dim)
83
+ self.norm2 = nn.LayerNorm(dim)
84
+
85
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
86
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
87
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
88
+
89
+
90
+ def forward(self, x, latents, shift=None, scale=None):
91
+ """
92
+ Args:
93
+ x (torch.Tensor): image features
94
+ shape (b, n1, D)
95
+ latents (torch.Tensor): latent features
96
+ shape (b, n2, D)
97
+ """
98
+ x = self.norm1(x)
99
+ latents = self.norm2(latents)
100
+
101
+ if shift is not None and scale is not None:
102
+ latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
103
+
104
+ b, l, _ = latents.shape
105
+
106
+ q = self.to_q(latents)
107
+ kv_input = torch.cat((x, latents), dim=-2)
108
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
109
+
110
+ q = reshape_tensor(q, self.heads)
111
+ k = reshape_tensor(k, self.heads)
112
+ v = reshape_tensor(v, self.heads)
113
+
114
+ # attention
115
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
116
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
117
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
118
+ out = weight @ v
119
+
120
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
121
+
122
+ return self.to_out(out)
123
+
124
+
125
+ class Resampler(nn.Module):
126
+ def __init__(
127
+ self,
128
+ dim=1024,
129
+ depth=8,
130
+ dim_head=64,
131
+ heads=16,
132
+ num_queries=8,
133
+ embedding_dim=768,
134
+ output_dim=1024,
135
+ ff_mult=4,
136
+ *args,
137
+ **kwargs,
138
+ ):
139
+ super().__init__()
140
+
141
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
142
+
143
+ self.proj_in = nn.Linear(embedding_dim, dim)
144
+
145
+ self.proj_out = nn.Linear(dim, output_dim)
146
+ self.norm_out = nn.LayerNorm(output_dim)
147
+
148
+ self.layers = nn.ModuleList([])
149
+ for _ in range(depth):
150
+ self.layers.append(
151
+ nn.ModuleList(
152
+ [
153
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
154
+ FeedForward(dim=dim, mult=ff_mult),
155
+ ]
156
+ )
157
+ )
158
+
159
+ def forward(self, x):
160
+
161
+ latents = self.latents.repeat(x.size(0), 1, 1)
162
+
163
+ x = self.proj_in(x)
164
+
165
+ for attn, ff in self.layers:
166
+ latents = attn(x, latents) + latents
167
+ latents = ff(latents) + latents
168
+
169
+ latents = self.proj_out(latents)
170
+ return self.norm_out(latents)
171
+
172
+
173
+ class TimeResampler(nn.Module):
174
+ def __init__(
175
+ self,
176
+ dim=1024,
177
+ depth=8,
178
+ dim_head=64,
179
+ heads=16,
180
+ num_queries=8,
181
+ embedding_dim=768,
182
+ output_dim=1024,
183
+ ff_mult=4,
184
+ timestep_in_dim=320,
185
+ timestep_flip_sin_to_cos=True,
186
+ timestep_freq_shift=0,
187
+ ):
188
+ super().__init__()
189
+
190
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
191
+
192
+ self.proj_in = nn.Linear(embedding_dim, dim)
193
+
194
+ self.proj_out = nn.Linear(dim, output_dim)
195
+ self.norm_out = nn.LayerNorm(output_dim)
196
+
197
+ self.layers = nn.ModuleList([])
198
+ for _ in range(depth):
199
+ self.layers.append(
200
+ nn.ModuleList(
201
+ [
202
+ # msa
203
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
204
+ # ff
205
+ FeedForward(dim=dim, mult=ff_mult),
206
+ # adaLN
207
+ nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
208
+ ]
209
+ )
210
+ )
211
+
212
+ # time
213
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
214
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
215
+
216
+ # adaLN
217
+ # self.adaLN_modulation = nn.Sequential(
218
+ # nn.SiLU(),
219
+ # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
220
+ # )
221
+
222
+
223
+ def forward(self, x, timestep, need_temb=False):
224
+ timestep_emb = self.embedding_time(x, timestep) # bs, dim
225
+
226
+ latents = self.latents.repeat(x.size(0), 1, 1)
227
+
228
+ x = self.proj_in(x)
229
+ x = x + timestep_emb[:, None]
230
+
231
+ for attn, ff, adaLN_modulation in self.layers:
232
+ shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
233
+ latents = attn(x, latents, shift_msa, scale_msa) + latents
234
+
235
+ res = latents
236
+ for idx_ff in range(len(ff)):
237
+ layer_ff = ff[idx_ff]
238
+ latents = layer_ff(latents)
239
+ if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
240
+ latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
241
+ latents = latents + res
242
+
243
+ # latents = ff(latents) + latents
244
+
245
+ latents = self.proj_out(latents)
246
+ latents = self.norm_out(latents)
247
+
248
+ if need_temb:
249
+ return latents, timestep_emb
250
+ else:
251
+ return latents
252
+
253
+
254
+
255
+ def embedding_time(self, sample, timestep):
256
+
257
+ # 1. time
258
+ timesteps = timestep
259
+ if not torch.is_tensor(timesteps):
260
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
261
+ # This would be a good case for the `match` statement (Python 3.10+)
262
+ is_mps = sample.device.type == "mps"
263
+ if isinstance(timestep, float):
264
+ dtype = torch.float32 if is_mps else torch.float64
265
+ else:
266
+ dtype = torch.int32 if is_mps else torch.int64
267
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
268
+ elif len(timesteps.shape) == 0:
269
+ timesteps = timesteps[None].to(sample.device)
270
+
271
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
272
+ timesteps = timesteps.expand(sample.shape[0])
273
+
274
+ t_emb = self.time_proj(timesteps)
275
+
276
+ # timesteps does not contain any weights and will always return f32 tensors
277
+ # but time_embedding might actually be running in fp16. so we need to cast here.
278
+ # there might be better ways to encapsulate this.
279
+ t_emb = t_emb.to(dtype=sample.dtype)
280
+
281
+ emb = self.time_embedding(t_emb, None)
282
+ return emb
283
+
284
+
285
+
286
+
287
+
288
+ if __name__ == '__main__':
289
+ model = TimeResampler(
290
+ dim=1280,
291
+ depth=4,
292
+ dim_head=64,
293
+ heads=20,
294
+ num_queries=16,
295
+ embedding_dim=512,
296
+ output_dim=2048,
297
+ ff_mult=4,
298
+ timestep_in_dim=320,
299
+ timestep_flip_sin_to_cos=True,
300
+ timestep_freq_shift=0,
302
+ )
303
+
304
+
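For reference, a minimal smoke test of the TimeResampler defined above; the import path (models.resampler) and all tensor shapes here are illustrative assumptions, not values taken from this repository, chosen only to match the constructor arguments used in the __main__ block.

import torch
from models.resampler import TimeResampler  # assumed import path

resampler = TimeResampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=16, embedding_dim=512, output_dim=2048,
    ff_mult=4, timestep_in_dim=320,
    timestep_flip_sin_to_cos=True, timestep_freq_shift=0,
)
image_embeds = torch.randn(2, 257, 512)  # (batch, tokens, embedding_dim); shapes are illustrative
timestep = torch.tensor([999, 999])      # one denoising-step index per batch element
tokens, temb = resampler(image_embeds, timestep, need_temb=True)
# tokens: (2, 16, 2048) = (batch, num_queries, output_dim); temb: (2, 1280) = (batch, dim)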
models/transformer_sd3.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from .attention import JointTransformerBlock
24
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormContinuous
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
36
+ """
37
+ The Transformer model introduced in Stable Diffusion 3.
38
+
39
+ Reference: https://arxiv.org/abs/2403.03206
40
+
41
+ Parameters:
42
+ sample_size (`int`): The width of the latent images. This is fixed during training since
43
+ it is used to learn a number of position embeddings.
44
+ patch_size (`int`): Patch size to turn the input data into small patches.
45
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
46
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
47
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
48
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
49
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
51
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
52
+ out_channels (`int`, defaults to 16): Number of output channels.
53
+
54
+ """
55
+
56
+ _supports_gradient_checkpointing = True
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ sample_size: int = 128,
62
+ patch_size: int = 2,
63
+ in_channels: int = 16,
64
+ num_layers: int = 18,
65
+ attention_head_dim: int = 64,
66
+ num_attention_heads: int = 18,
67
+ joint_attention_dim: int = 4096,
68
+ caption_projection_dim: int = 1152,
69
+ pooled_projection_dim: int = 2048,
70
+ out_channels: int = 16,
71
+ pos_embed_max_size: int = 96,
72
+ dual_attention_layers: Tuple[
73
+ int, ...
74
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
75
+ qk_norm: Optional[str] = None,
76
+ ):
77
+ super().__init__()
78
+ default_out_channels = in_channels
79
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
80
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
81
+
82
+ self.pos_embed = PatchEmbed(
83
+ height=self.config.sample_size,
84
+ width=self.config.sample_size,
85
+ patch_size=self.config.patch_size,
86
+ in_channels=self.config.in_channels,
87
+ embed_dim=self.inner_dim,
88
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
89
+ )
90
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
91
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
92
+ )
93
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
94
+
95
+ # `attention_head_dim` is doubled to account for the mixing.
96
+ # It needs to be crafted when we get the actual checkpoints.
97
+ self.transformer_blocks = nn.ModuleList(
98
+ [
99
+ JointTransformerBlock(
100
+ dim=self.inner_dim,
101
+ num_attention_heads=self.config.num_attention_heads,
102
+ attention_head_dim=self.config.attention_head_dim,
103
+ context_pre_only=i == num_layers - 1,
104
+ qk_norm=qk_norm,
105
+ use_dual_attention=True if i in dual_attention_layers else False,
106
+ )
107
+ for i in range(self.config.num_layers)
108
+ ]
109
+ )
110
+
111
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
112
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
113
+
114
+ self.gradient_checkpointing = False
115
+
116
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
117
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
118
+ """
119
+ Sets the attention processor to use [feed forward
120
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
121
+
122
+ Parameters:
123
+ chunk_size (`int`, *optional*):
124
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
125
+ over each tensor of dim=`dim`.
126
+ dim (`int`, *optional*, defaults to `0`):
127
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
128
+ or dim=1 (sequence length).
129
+ """
130
+ if dim not in [0, 1]:
131
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
132
+
133
+ # By default chunk size is 1
134
+ chunk_size = chunk_size or 1
135
+
136
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
137
+ if hasattr(module, "set_chunk_feed_forward"):
138
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
139
+
140
+ for child in module.children():
141
+ fn_recursive_feed_forward(child, chunk_size, dim)
142
+
143
+ for module in self.children():
144
+ fn_recursive_feed_forward(module, chunk_size, dim)
145
+
146
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
147
+ def disable_forward_chunking(self):
148
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
149
+ if hasattr(module, "set_chunk_feed_forward"):
150
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
151
+
152
+ for child in module.children():
153
+ fn_recursive_feed_forward(child, chunk_size, dim)
154
+
155
+ for module in self.children():
156
+ fn_recursive_feed_forward(module, None, 0)
157
+
158
+ @property
159
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
160
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
+ r"""
162
+ Returns:
163
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
164
+ indexed by their weight names.
165
+ """
166
+ # set recursively
167
+ processors = {}
168
+
169
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
+ if hasattr(module, "get_processor"):
171
+ processors[f"{name}.processor"] = module.get_processor()
172
+
173
+ for sub_name, child in module.named_children():
174
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
+
176
+ return processors
177
+
178
+ for name, module in self.named_children():
179
+ fn_recursive_add_processors(name, module, processors)
180
+
181
+ return processors
182
+
183
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
185
+ r"""
186
+ Sets the attention processor to use to compute attention.
187
+
188
+ Parameters:
189
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
190
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
191
+ for **all** `Attention` layers.
192
+
193
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
194
+ processor. This is strongly recommended when setting trainable attention processors.
195
+
196
+ """
197
+ count = len(self.attn_processors.keys())
198
+
199
+ if isinstance(processor, dict) and len(processor) != count:
200
+ raise ValueError(
201
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
202
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
203
+ )
204
+
205
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
206
+ if hasattr(module, "set_processor"):
207
+ if not isinstance(processor, dict):
208
+ module.set_processor(processor)
209
+ else:
210
+ module.set_processor(processor.pop(f"{name}.processor"))
211
+
212
+ for sub_name, child in module.named_children():
213
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
214
+
215
+ for name, module in self.named_children():
216
+ fn_recursive_attn_processor(name, module, processor)
217
+
218
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
219
+ def fuse_qkv_projections(self):
220
+ """
221
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
222
+ are fused. For cross-attention modules, key and value projection matrices are fused.
223
+
224
+ <Tip warning={true}>
225
+
226
+ This API is 🧪 experimental.
227
+
228
+ </Tip>
229
+ """
230
+ self.original_attn_processors = None
231
+
232
+ for _, attn_processor in self.attn_processors.items():
233
+ if "Added" in str(attn_processor.__class__.__name__):
234
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
235
+
236
+ self.original_attn_processors = self.attn_processors
237
+
238
+ for module in self.modules():
239
+ if isinstance(module, Attention):
240
+ module.fuse_projections(fuse=True)
241
+
242
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
243
+
244
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
245
+ def unfuse_qkv_projections(self):
246
+ """Disables the fused QKV projection if enabled.
247
+
248
+ <Tip warning={true}>
249
+
250
+ This API is 🧪 experimental.
251
+
252
+ </Tip>
253
+
254
+ """
255
+ if self.original_attn_processors is not None:
256
+ self.set_attn_processor(self.original_attn_processors)
257
+
258
+ def _set_gradient_checkpointing(self, module, value=False):
259
+ if hasattr(module, "gradient_checkpointing"):
260
+ module.gradient_checkpointing = value
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.FloatTensor,
265
+ encoder_hidden_states: torch.FloatTensor = None,
266
+ pooled_projections: torch.FloatTensor = None,
267
+ timestep: torch.LongTensor = None,
268
+ block_controlnet_hidden_states: List = None,
269
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270
+ return_dict: bool = True,
271
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272
+ """
273
+ The [`SD3Transformer2DModel`] forward method.
274
+
275
+ Args:
276
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
277
+ Input `hidden_states`.
278
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
279
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281
+ from the embeddings of input conditions.
282
+ timestep ( `torch.LongTensor`):
283
+ Used to indicate denoising step.
284
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285
+ A list of tensors that if specified are added to the residuals of transformer blocks.
286
+ joint_attention_kwargs (`dict`, *optional*):
287
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
288
+ `self.processor` in
289
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
290
+ return_dict (`bool`, *optional*, defaults to `True`):
291
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292
+ tuple.
293
+
294
+ Returns:
295
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
+ `tuple` where the first element is the sample tensor.
297
+ """
298
+ if joint_attention_kwargs is not None:
299
+ joint_attention_kwargs = joint_attention_kwargs.copy()
300
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
301
+ else:
302
+ lora_scale = 1.0
303
+
304
+ if USE_PEFT_BACKEND:
305
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
306
+ scale_lora_layers(self, lora_scale)
307
+ else:
308
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
309
+ logger.warning(
310
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
311
+ )
312
+
313
+ height, width = hidden_states.shape[-2:]
314
+
315
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
316
+ temb = self.time_text_embed(timestep, pooled_projections)
317
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318
+
319
+ for index_block, block in enumerate(self.transformer_blocks):
320
+ if self.training and self.gradient_checkpointing:
321
+
322
+ def create_custom_forward(module, return_dict=None):
323
+ def custom_forward(*inputs):
324
+ if return_dict is not None:
325
+ return module(*inputs, return_dict=return_dict)
326
+ else:
327
+ return module(*inputs)
328
+
329
+ return custom_forward
330
+
331
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
332
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
333
+ create_custom_forward(block),
334
+ hidden_states,
335
+ encoder_hidden_states,
336
+ temb,
337
+ joint_attention_kwargs,
338
+ **ckpt_kwargs,
339
+ )
340
+
341
+ else:
342
+ encoder_hidden_states, hidden_states = block(
343
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
344
+ joint_attention_kwargs=joint_attention_kwargs,
345
+ )
346
+
347
+ # controlnet residual
348
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
349
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
350
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
351
+
352
+ hidden_states = self.norm_out(hidden_states, temb)
353
+ hidden_states = self.proj_out(hidden_states)
354
+
355
+ # unpatchify
356
+ patch_size = self.config.patch_size
357
+ height = height // patch_size
358
+ width = width // patch_size
359
+
360
+ hidden_states = hidden_states.reshape(
361
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
362
+ )
363
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
+ output = hidden_states.reshape(
365
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
366
+ )
367
+
368
+ if USE_PEFT_BACKEND:
369
+ # remove `lora_scale` from each PEFT layer
370
+ unscale_lora_layers(self, lora_scale)
371
+
372
+ if not return_dict:
373
+ return (output,)
374
+
375
+ return Transformer2DModelOutput(sample=output)
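For reference, a minimal smoke test of the SD3Transformer2DModel defined above; the tiny configuration, import path (models.transformer_sd3), and tensor shapes are illustrative assumptions chosen so the example is cheap to run, not the SD3 checkpoint values, and it assumes the JointTransformerBlock in models/attention.py takes the plain text-to-image path when joint_attention_kwargs is None.

import torch
from models.transformer_sd3 import SD3Transformer2DModel  # assumed import path

model = SD3Transformer2DModel(
    sample_size=32, patch_size=2, in_channels=16, num_layers=2,
    attention_head_dim=64, num_attention_heads=4,          # inner_dim = 4 * 64 = 256
    joint_attention_dim=256, caption_projection_dim=256,   # caption projection must equal inner_dim
    pooled_projection_dim=256, out_channels=16, pos_embed_max_size=32,
)
latents = torch.randn(1, 16, 32, 32)     # (batch, in_channels, height, width)
prompt_embeds = torch.randn(1, 77, 256)  # (batch, seq_len, joint_attention_dim)
pooled = torch.randn(1, 256)             # (batch, pooled_projection_dim)
timestep = torch.tensor([500])
out = model(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled,
    timestep=timestep,
)
# out.sample: (1, 16, 32, 32) -- the model prediction with the same shape as the input latents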