rishh76 commited on
Commit
5701d6e
1 Parent(s): 76e5e07

Upload 6 files

Browse files
models/___init__.py ADDED
File without changes
models/configuration_chatglm.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class ChatGLMConfig(PretrainedConfig):
5
+ model_type = "chatglm"
6
+ def __init__(
7
+ self,
8
+ num_layers=28,
9
+ padded_vocab_size=65024,
10
+ hidden_size=4096,
11
+ ffn_hidden_size=13696,
12
+ kv_channels=128,
13
+ num_attention_heads=32,
14
+ seq_length=2048,
15
+ hidden_dropout=0.0,
16
+ classifier_dropout=None,
17
+ attention_dropout=0.0,
18
+ layernorm_epsilon=1e-5,
19
+ rmsnorm=True,
20
+ apply_residual_connection_post_layernorm=False,
21
+ post_layer_norm=True,
22
+ add_bias_linear=False,
23
+ add_qkv_bias=False,
24
+ bias_dropout_fusion=True,
25
+ multi_query_attention=False,
26
+ multi_query_group_num=1,
27
+ apply_query_key_layer_scaling=True,
28
+ attention_softmax_in_fp32=True,
29
+ fp32_residual_connection=False,
30
+ quantization_bit=0,
31
+ pre_seq_len=None,
32
+ prefix_projection=False,
33
+ **kwargs
34
+ ):
35
+ self.num_layers = num_layers
36
+ self.vocab_size = padded_vocab_size
37
+ self.padded_vocab_size = padded_vocab_size
38
+ self.hidden_size = hidden_size
39
+ self.ffn_hidden_size = ffn_hidden_size
40
+ self.kv_channels = kv_channels
41
+ self.num_attention_heads = num_attention_heads
42
+ self.seq_length = seq_length
43
+ self.hidden_dropout = hidden_dropout
44
+ self.classifier_dropout = classifier_dropout
45
+ self.attention_dropout = attention_dropout
46
+ self.layernorm_epsilon = layernorm_epsilon
47
+ self.rmsnorm = rmsnorm
48
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
49
+ self.post_layer_norm = post_layer_norm
50
+ self.add_bias_linear = add_bias_linear
51
+ self.add_qkv_bias = add_qkv_bias
52
+ self.bias_dropout_fusion = bias_dropout_fusion
53
+ self.multi_query_attention = multi_query_attention
54
+ self.multi_query_group_num = multi_query_group_num
55
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
56
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
57
+ self.fp32_residual_connection = fp32_residual_connection
58
+ self.quantization_bit = quantization_bit
59
+ self.pre_seq_len = pre_seq_len
60
+ self.prefix_projection = prefix_projection
61
+ super().__init__(**kwargs)
models/controlnet.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+
34
+ try:
35
+ from diffusers.unets.unet_2d_blocks import (
36
+ CrossAttnDownBlock2D,
37
+ DownBlock2D,
38
+ UNetMidBlock2D,
39
+ UNetMidBlock2DCrossAttn,
40
+ get_down_block,
41
+ )
42
+ from diffusers.unets.unet_2d_condition import UNet2DConditionModel
43
+ except:
44
+ from diffusers.models.unets.unet_2d_blocks import (
45
+ CrossAttnDownBlock2D,
46
+ DownBlock2D,
47
+ UNetMidBlock2D,
48
+ UNetMidBlock2DCrossAttn,
49
+ get_down_block,
50
+ )
51
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
52
+
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ @dataclass
59
+ class ControlNetOutput(BaseOutput):
60
+ """
61
+ The output of [`ControlNetModel`].
62
+
63
+ Args:
64
+ down_block_res_samples (`tuple[torch.Tensor]`):
65
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
66
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
67
+ used to condition the original UNet's downsampling activations.
68
+ mid_down_block_re_sample (`torch.Tensor`):
69
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
70
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
71
+ Output can be used to condition the original UNet's middle block activation.
72
+ """
73
+
74
+ down_block_res_samples: Tuple[torch.Tensor]
75
+ mid_block_res_sample: torch.Tensor
76
+
77
+
78
+ class ControlNetConditioningEmbedding(nn.Module):
79
+ """
80
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
81
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
82
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
83
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
84
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
85
+ model) to encode image-space conditions ... into feature maps ..."
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ conditioning_embedding_channels: int,
91
+ conditioning_channels: int = 3,
92
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
93
+ ):
94
+ super().__init__()
95
+
96
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
97
+
98
+ self.blocks = nn.ModuleList([])
99
+
100
+ for i in range(len(block_out_channels) - 1):
101
+ channel_in = block_out_channels[i]
102
+ channel_out = block_out_channels[i + 1]
103
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
104
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
105
+
106
+ self.conv_out = zero_module(
107
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
108
+ )
109
+
110
+ def forward(self, conditioning):
111
+ embedding = self.conv_in(conditioning)
112
+ embedding = F.silu(embedding)
113
+
114
+ for block in self.blocks:
115
+ embedding = block(embedding)
116
+ embedding = F.silu(embedding)
117
+
118
+ embedding = self.conv_out(embedding)
119
+
120
+ return embedding
121
+
122
+
123
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
124
+ """
125
+ A ControlNet model.
126
+
127
+ Args:
128
+ in_channels (`int`, defaults to 4):
129
+ The number of channels in the input sample.
130
+ flip_sin_to_cos (`bool`, defaults to `True`):
131
+ Whether to flip the sin to cos in the time embedding.
132
+ freq_shift (`int`, defaults to 0):
133
+ The frequency shift to apply to the time embedding.
134
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
135
+ The tuple of downsample blocks to use.
136
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
137
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
138
+ The tuple of output channels for each block.
139
+ layers_per_block (`int`, defaults to 2):
140
+ The number of layers per block.
141
+ downsample_padding (`int`, defaults to 1):
142
+ The padding to use for the downsampling convolution.
143
+ mid_block_scale_factor (`float`, defaults to 1):
144
+ The scale factor to use for the mid block.
145
+ act_fn (`str`, defaults to "silu"):
146
+ The activation function to use.
147
+ norm_num_groups (`int`, *optional*, defaults to 32):
148
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
149
+ in post-processing.
150
+ norm_eps (`float`, defaults to 1e-5):
151
+ The epsilon to use for the normalization.
152
+ cross_attention_dim (`int`, defaults to 1280):
153
+ The dimension of the cross attention features.
154
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
155
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
156
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
157
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
158
+ encoder_hid_dim (`int`, *optional*, defaults to None):
159
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
160
+ dimension to `cross_attention_dim`.
161
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
162
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
163
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
164
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
165
+ The dimension of the attention heads.
166
+ use_linear_projection (`bool`, defaults to `False`):
167
+ class_embed_type (`str`, *optional*, defaults to `None`):
168
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
169
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
170
+ addition_embed_type (`str`, *optional*, defaults to `None`):
171
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
172
+ "text". "text" will use the `TextTimeEmbedding` layer.
173
+ num_class_embeds (`int`, *optional*, defaults to 0):
174
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
175
+ class conditioning with `class_embed_type` equal to `None`.
176
+ upcast_attention (`bool`, defaults to `False`):
177
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
178
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
179
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
180
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
181
+ `class_embed_type="projection"`.
182
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
183
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
184
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
185
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
186
+ global_pool_conditions (`bool`, defaults to `False`):
187
+ TODO(Patrick) - unused parameter.
188
+ addition_embed_type_num_heads (`int`, defaults to 64):
189
+ The number of heads to use for the `TextTimeEmbedding` layer.
190
+ """
191
+
192
+ _supports_gradient_checkpointing = True
193
+
194
+ @register_to_config
195
+ def __init__(
196
+ self,
197
+ in_channels: int = 4,
198
+ conditioning_channels: int = 3,
199
+ flip_sin_to_cos: bool = True,
200
+ freq_shift: int = 0,
201
+ down_block_types: Tuple[str, ...] = (
202
+ "CrossAttnDownBlock2D",
203
+ "CrossAttnDownBlock2D",
204
+ "CrossAttnDownBlock2D",
205
+ "DownBlock2D",
206
+ ),
207
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
208
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
209
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
210
+ layers_per_block: int = 2,
211
+ downsample_padding: int = 1,
212
+ mid_block_scale_factor: float = 1,
213
+ act_fn: str = "silu",
214
+ norm_num_groups: Optional[int] = 32,
215
+ norm_eps: float = 1e-5,
216
+ cross_attention_dim: int = 1280,
217
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
218
+ encoder_hid_dim: Optional[int] = None,
219
+ encoder_hid_dim_type: Optional[str] = None,
220
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
221
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
222
+ use_linear_projection: bool = False,
223
+ class_embed_type: Optional[str] = None,
224
+ addition_embed_type: Optional[str] = None,
225
+ addition_time_embed_dim: Optional[int] = None,
226
+ num_class_embeds: Optional[int] = None,
227
+ upcast_attention: bool = False,
228
+ resnet_time_scale_shift: str = "default",
229
+ projection_class_embeddings_input_dim: Optional[int] = None,
230
+ controlnet_conditioning_channel_order: str = "rgb",
231
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
232
+ global_pool_conditions: bool = False,
233
+ addition_embed_type_num_heads: int = 64,
234
+ ):
235
+ super().__init__()
236
+
237
+ # If `num_attention_heads` is not defined (which is the case for most models)
238
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
239
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
240
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
241
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
242
+ # which is why we correct for the naming here.
243
+ num_attention_heads = num_attention_heads or attention_head_dim
244
+
245
+ # Check inputs
246
+ if len(block_out_channels) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if isinstance(transformer_layers_per_block, int):
262
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
263
+
264
+ # input
265
+ conv_in_kernel = 3
266
+ conv_in_padding = (conv_in_kernel - 1) // 2
267
+ self.conv_in = nn.Conv2d(
268
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
269
+ )
270
+
271
+ # time
272
+ time_embed_dim = block_out_channels[0] * 4
273
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
274
+ timestep_input_dim = block_out_channels[0]
275
+ self.time_embedding = TimestepEmbedding(
276
+ timestep_input_dim,
277
+ time_embed_dim,
278
+ act_fn=act_fn,
279
+ )
280
+
281
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
282
+ encoder_hid_dim_type = "text_proj"
283
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
284
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
285
+
286
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
287
+ raise ValueError(
288
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
289
+ )
290
+
291
+ if encoder_hid_dim_type == "text_proj":
292
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
293
+ elif encoder_hid_dim_type == "text_image_proj":
294
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
295
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
296
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
297
+ self.encoder_hid_proj = TextImageProjection(
298
+ text_embed_dim=encoder_hid_dim,
299
+ image_embed_dim=cross_attention_dim,
300
+ cross_attention_dim=cross_attention_dim,
301
+ )
302
+
303
+ elif encoder_hid_dim_type is not None:
304
+ raise ValueError(
305
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
306
+ )
307
+ else:
308
+ self.encoder_hid_proj = None
309
+
310
+ # class embedding
311
+ if class_embed_type is None and num_class_embeds is not None:
312
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
313
+ elif class_embed_type == "timestep":
314
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
315
+ elif class_embed_type == "identity":
316
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
317
+ elif class_embed_type == "projection":
318
+ if projection_class_embeddings_input_dim is None:
319
+ raise ValueError(
320
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
321
+ )
322
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
323
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
324
+ # 2. it projects from an arbitrary input dimension.
325
+ #
326
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
327
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
328
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
329
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
330
+ else:
331
+ self.class_embedding = None
332
+
333
+ if addition_embed_type == "text":
334
+ if encoder_hid_dim is not None:
335
+ text_time_embedding_from_dim = encoder_hid_dim
336
+ else:
337
+ text_time_embedding_from_dim = cross_attention_dim
338
+
339
+ self.add_embedding = TextTimeEmbedding(
340
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
341
+ )
342
+ elif addition_embed_type == "text_image":
343
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
344
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
345
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
346
+ self.add_embedding = TextImageTimeEmbedding(
347
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
348
+ )
349
+ elif addition_embed_type == "text_time":
350
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
351
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
352
+
353
+ elif addition_embed_type is not None:
354
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
355
+
356
+ # control net conditioning embedding
357
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
358
+ conditioning_embedding_channels=block_out_channels[0],
359
+ block_out_channels=conditioning_embedding_out_channels,
360
+ conditioning_channels=conditioning_channels,
361
+ )
362
+
363
+ self.down_blocks = nn.ModuleList([])
364
+ self.controlnet_down_blocks = nn.ModuleList([])
365
+
366
+ if isinstance(only_cross_attention, bool):
367
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
368
+
369
+ if isinstance(attention_head_dim, int):
370
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
371
+
372
+ if isinstance(num_attention_heads, int):
373
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
374
+
375
+ # down
376
+ output_channel = block_out_channels[0]
377
+
378
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
379
+ controlnet_block = zero_module(controlnet_block)
380
+ self.controlnet_down_blocks.append(controlnet_block)
381
+
382
+ for i, down_block_type in enumerate(down_block_types):
383
+ input_channel = output_channel
384
+ output_channel = block_out_channels[i]
385
+ is_final_block = i == len(block_out_channels) - 1
386
+
387
+ down_block = get_down_block(
388
+ down_block_type,
389
+ num_layers=layers_per_block,
390
+ transformer_layers_per_block=transformer_layers_per_block[i],
391
+ in_channels=input_channel,
392
+ out_channels=output_channel,
393
+ temb_channels=time_embed_dim,
394
+ add_downsample=not is_final_block,
395
+ resnet_eps=norm_eps,
396
+ resnet_act_fn=act_fn,
397
+ resnet_groups=norm_num_groups,
398
+ cross_attention_dim=cross_attention_dim,
399
+ num_attention_heads=num_attention_heads[i],
400
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
401
+ downsample_padding=downsample_padding,
402
+ use_linear_projection=use_linear_projection,
403
+ only_cross_attention=only_cross_attention[i],
404
+ upcast_attention=upcast_attention,
405
+ resnet_time_scale_shift=resnet_time_scale_shift,
406
+ )
407
+ self.down_blocks.append(down_block)
408
+
409
+ for _ in range(layers_per_block):
410
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
411
+ controlnet_block = zero_module(controlnet_block)
412
+ self.controlnet_down_blocks.append(controlnet_block)
413
+
414
+ if not is_final_block:
415
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
416
+ controlnet_block = zero_module(controlnet_block)
417
+ self.controlnet_down_blocks.append(controlnet_block)
418
+
419
+ # mid
420
+ mid_block_channel = block_out_channels[-1]
421
+
422
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
423
+ controlnet_block = zero_module(controlnet_block)
424
+ self.controlnet_mid_block = controlnet_block
425
+
426
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
427
+ self.mid_block = UNetMidBlock2DCrossAttn(
428
+ transformer_layers_per_block=transformer_layers_per_block[-1],
429
+ in_channels=mid_block_channel,
430
+ temb_channels=time_embed_dim,
431
+ resnet_eps=norm_eps,
432
+ resnet_act_fn=act_fn,
433
+ output_scale_factor=mid_block_scale_factor,
434
+ resnet_time_scale_shift=resnet_time_scale_shift,
435
+ cross_attention_dim=cross_attention_dim,
436
+ num_attention_heads=num_attention_heads[-1],
437
+ resnet_groups=norm_num_groups,
438
+ use_linear_projection=use_linear_projection,
439
+ upcast_attention=upcast_attention,
440
+ )
441
+ elif mid_block_type == "UNetMidBlock2D":
442
+ self.mid_block = UNetMidBlock2D(
443
+ in_channels=block_out_channels[-1],
444
+ temb_channels=time_embed_dim,
445
+ num_layers=0,
446
+ resnet_eps=norm_eps,
447
+ resnet_act_fn=act_fn,
448
+ output_scale_factor=mid_block_scale_factor,
449
+ resnet_groups=norm_num_groups,
450
+ resnet_time_scale_shift=resnet_time_scale_shift,
451
+ add_attention=False,
452
+ )
453
+ else:
454
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
455
+
456
+ @classmethod
457
+ def from_unet(
458
+ cls,
459
+ unet: UNet2DConditionModel,
460
+ controlnet_conditioning_channel_order: str = "rgb",
461
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
462
+ load_weights_from_unet: bool = True,
463
+ conditioning_channels: int = 3,
464
+ ):
465
+ r"""
466
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
467
+
468
+ Parameters:
469
+ unet (`UNet2DConditionModel`):
470
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
471
+ where applicable.
472
+ """
473
+ transformer_layers_per_block = (
474
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
475
+ )
476
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
477
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
478
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
479
+ addition_time_embed_dim = (
480
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
481
+ )
482
+
483
+ controlnet = cls(
484
+ encoder_hid_dim=encoder_hid_dim,
485
+ encoder_hid_dim_type=encoder_hid_dim_type,
486
+ addition_embed_type=addition_embed_type,
487
+ addition_time_embed_dim=addition_time_embed_dim,
488
+ transformer_layers_per_block=transformer_layers_per_block,
489
+ in_channels=unet.config.in_channels,
490
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
491
+ freq_shift=unet.config.freq_shift,
492
+ down_block_types=unet.config.down_block_types,
493
+ only_cross_attention=unet.config.only_cross_attention,
494
+ block_out_channels=unet.config.block_out_channels,
495
+ layers_per_block=unet.config.layers_per_block,
496
+ downsample_padding=unet.config.downsample_padding,
497
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
498
+ act_fn=unet.config.act_fn,
499
+ norm_num_groups=unet.config.norm_num_groups,
500
+ norm_eps=unet.config.norm_eps,
501
+ cross_attention_dim=unet.config.cross_attention_dim,
502
+ attention_head_dim=unet.config.attention_head_dim,
503
+ num_attention_heads=unet.config.num_attention_heads,
504
+ use_linear_projection=unet.config.use_linear_projection,
505
+ class_embed_type=unet.config.class_embed_type,
506
+ num_class_embeds=unet.config.num_class_embeds,
507
+ upcast_attention=unet.config.upcast_attention,
508
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
+ mid_block_type=unet.config.mid_block_type,
511
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
512
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
513
+ conditioning_channels=conditioning_channels,
514
+ )
515
+
516
+ if load_weights_from_unet:
517
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
518
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
519
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
520
+
521
+ if controlnet.class_embedding:
522
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
523
+
524
+ if hasattr(controlnet, "add_embedding"):
525
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
526
+
527
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
528
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
529
+
530
+ return controlnet
531
+
532
+ @property
533
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
534
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
535
+ r"""
536
+ Returns:
537
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
538
+ indexed by its weight name.
539
+ """
540
+ # set recursively
541
+ processors = {}
542
+
543
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
544
+ if hasattr(module, "get_processor"):
545
+ processors[f"{name}.processor"] = module.get_processor()
546
+
547
+ for sub_name, child in module.named_children():
548
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
549
+
550
+ return processors
551
+
552
+ for name, module in self.named_children():
553
+ fn_recursive_add_processors(name, module, processors)
554
+
555
+ return processors
556
+
557
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
558
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
559
+ r"""
560
+ Sets the attention processor to use to compute attention.
561
+
562
+ Parameters:
563
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
564
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
565
+ for **all** `Attention` layers.
566
+
567
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
568
+ processor. This is strongly recommended when setting trainable attention processors.
569
+
570
+ """
571
+ count = len(self.attn_processors.keys())
572
+
573
+ if isinstance(processor, dict) and len(processor) != count:
574
+ raise ValueError(
575
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
576
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
577
+ )
578
+
579
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
580
+ if hasattr(module, "set_processor"):
581
+ if not isinstance(processor, dict):
582
+ module.set_processor(processor)
583
+ else:
584
+ module.set_processor(processor.pop(f"{name}.processor"))
585
+
586
+ for sub_name, child in module.named_children():
587
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
588
+
589
+ for name, module in self.named_children():
590
+ fn_recursive_attn_processor(name, module, processor)
591
+
592
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
593
+ def set_default_attn_processor(self):
594
+ """
595
+ Disables custom attention processors and sets the default attention implementation.
596
+ """
597
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
598
+ processor = AttnAddedKVProcessor()
599
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
600
+ processor = AttnProcessor()
601
+ else:
602
+ raise ValueError(
603
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
604
+ )
605
+
606
+ self.set_attn_processor(processor)
607
+
608
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
609
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
610
+ r"""
611
+ Enable sliced attention computation.
612
+
613
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
614
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
615
+
616
+ Args:
617
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
618
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
619
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
620
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
621
+ must be a multiple of `slice_size`.
622
+ """
623
+ sliceable_head_dims = []
624
+
625
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
626
+ if hasattr(module, "set_attention_slice"):
627
+ sliceable_head_dims.append(module.sliceable_head_dim)
628
+
629
+ for child in module.children():
630
+ fn_recursive_retrieve_sliceable_dims(child)
631
+
632
+ # retrieve number of attention layers
633
+ for module in self.children():
634
+ fn_recursive_retrieve_sliceable_dims(module)
635
+
636
+ num_sliceable_layers = len(sliceable_head_dims)
637
+
638
+ if slice_size == "auto":
639
+ # half the attention head size is usually a good trade-off between
640
+ # speed and memory
641
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
642
+ elif slice_size == "max":
643
+ # make smallest slice possible
644
+ slice_size = num_sliceable_layers * [1]
645
+
646
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
647
+
648
+ if len(slice_size) != len(sliceable_head_dims):
649
+ raise ValueError(
650
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
651
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
652
+ )
653
+
654
+ for i in range(len(slice_size)):
655
+ size = slice_size[i]
656
+ dim = sliceable_head_dims[i]
657
+ if size is not None and size > dim:
658
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
659
+
660
+ # Recursively walk through all the children.
661
+ # Any children which exposes the set_attention_slice method
662
+ # gets the message
663
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
664
+ if hasattr(module, "set_attention_slice"):
665
+ module.set_attention_slice(slice_size.pop())
666
+
667
+ for child in module.children():
668
+ fn_recursive_set_attention_slice(child, slice_size)
669
+
670
+ reversed_slice_size = list(reversed(slice_size))
671
+ for module in self.children():
672
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
673
+
674
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
675
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
676
+ module.gradient_checkpointing = value
677
+
678
+ def forward(
679
+ self,
680
+ sample: torch.Tensor,
681
+ timestep: Union[torch.Tensor, float, int],
682
+ encoder_hidden_states: torch.Tensor,
683
+ controlnet_cond: torch.Tensor,
684
+ conditioning_scale: float = 1.0,
685
+ class_labels: Optional[torch.Tensor] = None,
686
+ timestep_cond: Optional[torch.Tensor] = None,
687
+ attention_mask: Optional[torch.Tensor] = None,
688
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
689
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
690
+ guess_mode: bool = False,
691
+ return_dict: bool = True,
692
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
693
+ """
694
+ The [`ControlNetModel`] forward method.
695
+
696
+ Args:
697
+ sample (`torch.Tensor`):
698
+ The noisy input tensor.
699
+ timestep (`Union[torch.Tensor, float, int]`):
700
+ The number of timesteps to denoise an input.
701
+ encoder_hidden_states (`torch.Tensor`):
702
+ The encoder hidden states.
703
+ controlnet_cond (`torch.Tensor`):
704
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
705
+ conditioning_scale (`float`, defaults to `1.0`):
706
+ The scale factor for ControlNet outputs.
707
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
708
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
709
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
710
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
711
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
712
+ embeddings.
713
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
714
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
715
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
716
+ negative values to the attention scores corresponding to "discard" tokens.
717
+ added_cond_kwargs (`dict`):
718
+ Additional conditions for the Stable Diffusion XL UNet.
719
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
720
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
721
+ guess_mode (`bool`, defaults to `False`):
722
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
723
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
724
+ return_dict (`bool`, defaults to `True`):
725
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
726
+
727
+ Returns:
728
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
729
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
730
+ returned where the first element is the sample tensor.
731
+ """
732
+ # check channel order
733
+ channel_order = self.config.controlnet_conditioning_channel_order
734
+
735
+ if channel_order == "rgb":
736
+ # in rgb order by default
737
+ ...
738
+ elif channel_order == "bgr":
739
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
740
+ else:
741
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
742
+
743
+ # prepare attention_mask
744
+ if attention_mask is not None:
745
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
746
+ attention_mask = attention_mask.unsqueeze(1)
747
+
748
+ #Todo
749
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
750
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
751
+
752
+ # 1. time
753
+ timesteps = timestep
754
+ if not torch.is_tensor(timesteps):
755
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
756
+ # This would be a good case for the `match` statement (Python 3.10+)
757
+ is_mps = sample.device.type == "mps"
758
+ if isinstance(timestep, float):
759
+ dtype = torch.float32 if is_mps else torch.float64
760
+ else:
761
+ dtype = torch.int32 if is_mps else torch.int64
762
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
763
+ elif len(timesteps.shape) == 0:
764
+ timesteps = timesteps[None].to(sample.device)
765
+
766
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
767
+ timesteps = timesteps.expand(sample.shape[0])
768
+
769
+ t_emb = self.time_proj(timesteps)
770
+
771
+ # timesteps does not contain any weights and will always return f32 tensors
772
+ # but time_embedding might actually be running in fp16. so we need to cast here.
773
+ # there might be better ways to encapsulate this.
774
+ t_emb = t_emb.to(dtype=sample.dtype)
775
+
776
+ emb = self.time_embedding(t_emb, timestep_cond)
777
+ aug_emb = None
778
+
779
+ if self.class_embedding is not None:
780
+ if class_labels is None:
781
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
782
+
783
+ if self.config.class_embed_type == "timestep":
784
+ class_labels = self.time_proj(class_labels)
785
+
786
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
787
+ emb = emb + class_emb
788
+
789
+ if self.config.addition_embed_type is not None:
790
+ if self.config.addition_embed_type == "text":
791
+ aug_emb = self.add_embedding(encoder_hidden_states)
792
+
793
+ elif self.config.addition_embed_type == "text_time":
794
+ if "text_embeds" not in added_cond_kwargs:
795
+ raise ValueError(
796
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
797
+ )
798
+ text_embeds = added_cond_kwargs.get("text_embeds")
799
+ if "time_ids" not in added_cond_kwargs:
800
+ raise ValueError(
801
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
802
+ )
803
+ time_ids = added_cond_kwargs.get("time_ids")
804
+ time_embeds = self.add_time_proj(time_ids.flatten())
805
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
806
+
807
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
808
+ add_embeds = add_embeds.to(emb.dtype)
809
+ aug_emb = self.add_embedding(add_embeds)
810
+
811
+ emb = emb + aug_emb if aug_emb is not None else emb
812
+
813
+ # 2. pre-process
814
+ sample = self.conv_in(sample)
815
+
816
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
817
+ sample = sample + controlnet_cond
818
+
819
+ # 3. down
820
+ down_block_res_samples = (sample,)
821
+ for downsample_block in self.down_blocks:
822
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
823
+ sample, res_samples = downsample_block(
824
+ hidden_states=sample,
825
+ temb=emb,
826
+ encoder_hidden_states=encoder_hidden_states,
827
+ attention_mask=attention_mask,
828
+ cross_attention_kwargs=cross_attention_kwargs,
829
+ )
830
+ else:
831
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
832
+
833
+ down_block_res_samples += res_samples
834
+
835
+ # 4. mid
836
+ if self.mid_block is not None:
837
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
838
+ sample = self.mid_block(
839
+ sample,
840
+ emb,
841
+ encoder_hidden_states=encoder_hidden_states,
842
+ attention_mask=attention_mask,
843
+ cross_attention_kwargs=cross_attention_kwargs,
844
+ )
845
+ else:
846
+ sample = self.mid_block(sample, emb)
847
+
848
+ # 5. Control net blocks
849
+
850
+ controlnet_down_block_res_samples = ()
851
+
852
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
853
+ down_block_res_sample = controlnet_block(down_block_res_sample)
854
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
855
+
856
+ down_block_res_samples = controlnet_down_block_res_samples
857
+
858
+ mid_block_res_sample = self.controlnet_mid_block(sample)
859
+
860
+ # 6. scaling
861
+ if guess_mode and not self.config.global_pool_conditions:
862
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
863
+ scales = scales * conditioning_scale
864
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
865
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
866
+ else:
867
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
868
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
869
+
870
+ if self.config.global_pool_conditions:
871
+ down_block_res_samples = [
872
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
873
+ ]
874
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
875
+
876
+ if not return_dict:
877
+ return (down_block_res_samples, mid_block_res_sample)
878
+
879
+ return ControlNetOutput(
880
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
881
+ )
882
+
883
+
884
+ def zero_module(module):
885
+ for p in module.parameters():
886
+ nn.init.zeros_(p)
887
+ return module
models/modeling_chatglm.py ADDED
@@ -0,0 +1,1298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch ChatGLM model. """
2
+
3
+ import math
4
+ import copy
5
+ import warnings
6
+ import re
7
+ import sys
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss, LayerNorm
14
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
+ from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
+ from copy import deepcopy
18
+
19
+ from transformers.modeling_outputs import (
20
+ BaseModelOutputWithPast,
21
+ CausalLMOutputWithPast,
22
+ SequenceClassifierOutputWithPast,
23
+ )
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.utils import logging
26
+ from transformers.generation.logits_process import LogitsProcessor
27
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
+
29
+ try:
30
+ from .configuration_chatglm import ChatGLMConfig
31
+ except:
32
+ from configuration_chatglm import ChatGLMConfig
33
+
34
+
35
+ # flags required to enable jit fusion kernels
36
+
37
+ if sys.platform != 'darwin':
38
+ torch._C._jit_set_profiling_mode(False)
39
+ torch._C._jit_set_profiling_executor(False)
40
+ torch._C._jit_override_can_fuse_on_cpu(True)
41
+ torch._C._jit_override_can_fuse_on_gpu(True)
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
46
+ _CONFIG_FOR_DOC = "ChatGLM6BConfig"
47
+
48
+ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "THUDM/chatglm3-6b-base",
50
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
51
+ ]
52
+
53
+
54
+ def default_init(cls, *args, **kwargs):
55
+ return cls(*args, **kwargs)
56
+
57
+
58
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
59
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
61
+ scores.zero_()
62
+ scores[..., 5] = 5e4
63
+ return scores
64
+
65
+
66
+ class PrefixEncoder(torch.nn.Module):
67
+ """
68
+ The torch.nn model to encode the prefix
69
+ Input shape: (batch-size, prefix-length)
70
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
71
+ """
72
+
73
+ def __init__(self, config: ChatGLMConfig):
74
+ super().__init__()
75
+ self.prefix_projection = config.prefix_projection
76
+ if self.prefix_projection:
77
+ # Use a two-layer MLP to encode the prefix
78
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
79
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
80
+ self.trans = torch.nn.Sequential(
81
+ torch.nn.Linear(kv_size, config.hidden_size),
82
+ torch.nn.Tanh(),
83
+ torch.nn.Linear(config.hidden_size, kv_size)
84
+ )
85
+ else:
86
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
87
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
88
+
89
+ def forward(self, prefix: torch.Tensor):
90
+ if self.prefix_projection:
91
+ prefix_tokens = self.embedding(prefix)
92
+ past_key_values = self.trans(prefix_tokens)
93
+ else:
94
+ past_key_values = self.embedding(prefix)
95
+ return past_key_values
96
+
97
+
98
+ def split_tensor_along_last_dim(
99
+ tensor: torch.Tensor,
100
+ num_partitions: int,
101
+ contiguous_split_chunks: bool = False,
102
+ ) -> List[torch.Tensor]:
103
+ """Split a tensor along its last dimension.
104
+
105
+ Arguments:
106
+ tensor: input tensor.
107
+ num_partitions: number of partitions to split the tensor
108
+ contiguous_split_chunks: If True, make each chunk contiguous
109
+ in memory.
110
+
111
+ Returns:
112
+ A list of Tensors
113
+ """
114
+ # Get the size and dimension.
115
+ last_dim = tensor.dim() - 1
116
+ last_dim_size = tensor.size()[last_dim] // num_partitions
117
+ # Split.
118
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
119
+ # Note: torch.split does not create contiguous tensors by default.
120
+ if contiguous_split_chunks:
121
+ return tuple(chunk.contiguous() for chunk in tensor_list)
122
+
123
+ return tensor_list
124
+
125
+
126
+ class RotaryEmbedding(nn.Module):
127
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
128
+ super().__init__()
129
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
130
+ self.register_buffer("inv_freq", inv_freq)
131
+ self.dim = dim
132
+ self.original_impl = original_impl
133
+
134
+ def forward_impl(
135
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
136
+ ):
137
+ """Enhanced Transformer with Rotary Position Embedding.
138
+
139
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
140
+ transformers/rope/__init__.py. MIT License:
141
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
142
+ """
143
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
144
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
145
+
146
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
147
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
148
+
149
+ # Calculate the product of position index and $\theta_i$
150
+ idx_theta = torch.outer(seq_idx, theta).float()
151
+
152
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
153
+
154
+ # this is to mimic the behaviour of complex32, else we will get different results
155
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
156
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
157
+ return cache
158
+
159
+ def forward(self, max_seq_len, offset=0):
160
+ return self.forward_impl(
161
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
162
+ )
163
+
164
+
165
+ @torch.jit.script
166
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
167
+ # x: [sq, b, np, hn]
168
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
169
+ rot_dim = rope_cache.shape[-2] * 2
170
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
171
+ # truncate to support variable sizes
172
+ rope_cache = rope_cache[:sq]
173
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
174
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
175
+ x_out2 = torch.stack(
176
+ [
177
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
178
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
179
+ ],
180
+ -1,
181
+ )
182
+ x_out2 = x_out2.flatten(3)
183
+ return torch.cat((x_out2, x_pass), dim=-1)
184
+
185
+
186
+ class RMSNorm(torch.nn.Module):
187
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
188
+ super().__init__()
189
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
190
+ self.eps = eps
191
+
192
+ def forward(self, hidden_states: torch.Tensor):
193
+ input_dtype = hidden_states.dtype
194
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
195
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
196
+
197
+ return (self.weight * hidden_states).to(input_dtype)
198
+
199
+
200
+ class CoreAttention(torch.nn.Module):
201
+ def __init__(self, config: ChatGLMConfig, layer_number):
202
+ super(CoreAttention, self).__init__()
203
+
204
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
205
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
206
+ if self.apply_query_key_layer_scaling:
207
+ self.attention_softmax_in_fp32 = True
208
+ self.layer_number = max(1, layer_number)
209
+
210
+ projection_size = config.kv_channels * config.num_attention_heads
211
+
212
+ # Per attention head and per partition values.
213
+ self.hidden_size_per_partition = projection_size
214
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
215
+ self.num_attention_heads_per_partition = config.num_attention_heads
216
+
217
+ coeff = None
218
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
219
+ if self.apply_query_key_layer_scaling:
220
+ coeff = self.layer_number
221
+ self.norm_factor *= coeff
222
+ self.coeff = coeff
223
+
224
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
225
+
226
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
227
+ pytorch_major_version = int(torch.__version__.split('.')[0])
228
+ if pytorch_major_version >= 2:
229
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
230
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
231
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
232
+ is_causal=True)
233
+ else:
234
+ if attention_mask is not None:
235
+ attention_mask = ~attention_mask
236
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
237
+ attention_mask)
238
+ context_layer = context_layer.permute(2, 0, 1, 3)
239
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
240
+ context_layer = context_layer.reshape(*new_context_layer_shape)
241
+ else:
242
+ # Raw attention scores
243
+
244
+ # [b, np, sq, sk]
245
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
246
+
247
+ # [sq, b, np, hn] -> [sq, b * np, hn]
248
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
249
+ # [sk, b, np, hn] -> [sk, b * np, hn]
250
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
251
+
252
+ # preallocting input tensor: [b * np, sq, sk]
253
+ matmul_input_buffer = torch.empty(
254
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
255
+ device=query_layer.device
256
+ )
257
+
258
+ # Raw attention scores. [b * np, sq, sk]
259
+ matmul_result = torch.baddbmm(
260
+ matmul_input_buffer,
261
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
262
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
263
+ beta=0.0,
264
+ alpha=(1.0 / self.norm_factor),
265
+ )
266
+
267
+ # change view to [b, np, sq, sk]
268
+ attention_scores = matmul_result.view(*output_size)
269
+
270
+ # ===========================
271
+ # Attention probs and dropout
272
+ # ===========================
273
+
274
+ # attention scores and attention mask [b, np, sq, sk]
275
+ if self.attention_softmax_in_fp32:
276
+ attention_scores = attention_scores.float()
277
+ if self.coeff is not None:
278
+ attention_scores = attention_scores * self.coeff
279
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
280
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
281
+ device=attention_scores.device, dtype=torch.bool)
282
+ attention_mask.tril_()
283
+ attention_mask = ~attention_mask
284
+ if attention_mask is not None:
285
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
286
+ attention_probs = F.softmax(attention_scores, dim=-1)
287
+ attention_probs = attention_probs.type_as(value_layer)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.attention_dropout(attention_probs)
292
+ # =========================
293
+ # Context layer. [sq, b, hp]
294
+ # =========================
295
+
296
+ # value_layer -> context layer.
297
+ # [sk, b, np, hn] --> [b, np, sq, hn]
298
+
299
+ # context layer shape: [b, np, sq, hn]
300
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
301
+ # change view [sk, b * np, hn]
302
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
303
+ # change view [b * np, sq, sk]
304
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
305
+ # matmul: [b * np, sq, hn]
306
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
307
+ # change view [b, np, sq, hn]
308
+ context_layer = context_layer.view(*output_size)
309
+ # [b, np, sq, hn] --> [sq, b, np, hn]
310
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
311
+ # [sq, b, np, hn] --> [sq, b, hp]
312
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
313
+ context_layer = context_layer.view(*new_context_layer_shape)
314
+
315
+ return context_layer
316
+
317
+
318
+ class SelfAttention(torch.nn.Module):
319
+ """Parallel self-attention layer abstract class.
320
+
321
+ Self-attention layer takes input with size [s, b, h]
322
+ and returns output of the same size.
323
+ """
324
+
325
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
326
+ super(SelfAttention, self).__init__()
327
+ self.layer_number = max(1, layer_number)
328
+
329
+ self.projection_size = config.kv_channels * config.num_attention_heads
330
+
331
+ # Per attention head and per partition values.
332
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
333
+ self.num_attention_heads_per_partition = config.num_attention_heads
334
+
335
+ self.multi_query_attention = config.multi_query_attention
336
+ self.qkv_hidden_size = 3 * self.projection_size
337
+ if self.multi_query_attention:
338
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
339
+ self.qkv_hidden_size = (
340
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
341
+ )
342
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
343
+ bias=config.add_bias_linear or config.add_qkv_bias,
344
+ device=device, **_config_to_kwargs(config)
345
+ )
346
+
347
+ self.core_attention = CoreAttention(config, self.layer_number)
348
+
349
+ # Output.
350
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
351
+ device=device, **_config_to_kwargs(config)
352
+ )
353
+
354
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
355
+ if self.multi_query_attention:
356
+ num_attention_heads = self.num_multi_query_groups_per_partition
357
+ else:
358
+ num_attention_heads = self.num_attention_heads_per_partition
359
+ return torch.empty(
360
+ inference_max_sequence_len,
361
+ batch_size,
362
+ num_attention_heads,
363
+ self.hidden_size_per_attention_head,
364
+ dtype=dtype,
365
+ device=device,
366
+ )
367
+
368
+ def forward(
369
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
370
+ ):
371
+ # hidden_states: [sq, b, h]
372
+
373
+ # =================================================
374
+ # Pre-allocate memory for key-values for inference.
375
+ # =================================================
376
+ # =====================
377
+ # Query, Key, and Value
378
+ # =====================
379
+
380
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
381
+ mixed_x_layer = self.query_key_value(hidden_states)
382
+
383
+ if self.multi_query_attention:
384
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
385
+ [
386
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
387
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
388
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
389
+ ],
390
+ dim=-1,
391
+ )
392
+ query_layer = query_layer.view(
393
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
394
+ )
395
+ key_layer = key_layer.view(
396
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
397
+ )
398
+ value_layer = value_layer.view(
399
+ value_layer.size()[:-1]
400
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
401
+ )
402
+ else:
403
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
404
+ (self.num_attention_heads_per_partition,
405
+ 3 * self.hidden_size_per_attention_head)
406
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
407
+
408
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
409
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
410
+
411
+ # apply relative positional encoding (rotary embedding)
412
+ if rotary_pos_emb is not None:
413
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
414
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
415
+
416
+ # adjust key and value for inference
417
+ if kv_cache is not None:
418
+ cache_k, cache_v = kv_cache
419
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
420
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
421
+ if use_cache:
422
+ kv_cache = (key_layer, value_layer)
423
+ else:
424
+ kv_cache = None
425
+
426
+ if self.multi_query_attention:
427
+ key_layer = key_layer.unsqueeze(-2)
428
+ key_layer = key_layer.expand(
429
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
430
+ )
431
+ key_layer = key_layer.contiguous().view(
432
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
433
+ )
434
+ value_layer = value_layer.unsqueeze(-2)
435
+ value_layer = value_layer.expand(
436
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
437
+ )
438
+ value_layer = value_layer.contiguous().view(
439
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
440
+ )
441
+
442
+ # ==================================
443
+ # core attention computation
444
+ # ==================================
445
+
446
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
447
+
448
+ # =================
449
+ # Output. [sq, b, h]
450
+ # =================
451
+
452
+ output = self.dense(context_layer)
453
+
454
+ return output, kv_cache
455
+
456
+
457
+ def _config_to_kwargs(args):
458
+ common_kwargs = {
459
+ "dtype": args.torch_dtype,
460
+ }
461
+ return common_kwargs
462
+
463
+
464
+ class MLP(torch.nn.Module):
465
+ """MLP.
466
+
467
+ MLP will take the input with h hidden state, project it to 4*h
468
+ hidden dimension, perform nonlinear transformation, and project the
469
+ state back into h hidden dimension.
470
+ """
471
+
472
+ def __init__(self, config: ChatGLMConfig, device=None):
473
+ super(MLP, self).__init__()
474
+
475
+ self.add_bias = config.add_bias_linear
476
+
477
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
478
+ self.dense_h_to_4h = nn.Linear(
479
+ config.hidden_size,
480
+ config.ffn_hidden_size * 2,
481
+ bias=self.add_bias,
482
+ device=device,
483
+ **_config_to_kwargs(config)
484
+ )
485
+
486
+ def swiglu(x):
487
+ x = torch.chunk(x, 2, dim=-1)
488
+ return F.silu(x[0]) * x[1]
489
+
490
+ self.activation_func = swiglu
491
+
492
+ # Project back to h.
493
+ self.dense_4h_to_h = nn.Linear(
494
+ config.ffn_hidden_size,
495
+ config.hidden_size,
496
+ bias=self.add_bias,
497
+ device=device,
498
+ **_config_to_kwargs(config)
499
+ )
500
+
501
+ def forward(self, hidden_states):
502
+ # [s, b, 4hp]
503
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
504
+ intermediate_parallel = self.activation_func(intermediate_parallel)
505
+ # [s, b, h]
506
+ output = self.dense_4h_to_h(intermediate_parallel)
507
+ return output
508
+
509
+
510
+ class GLMBlock(torch.nn.Module):
511
+ """A single transformer layer.
512
+
513
+ Transformer layer takes input with size [s, b, h] and returns an
514
+ output of the same size.
515
+ """
516
+
517
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
518
+ super(GLMBlock, self).__init__()
519
+ self.layer_number = layer_number
520
+
521
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
522
+
523
+ self.fp32_residual_connection = config.fp32_residual_connection
524
+
525
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
526
+ # Layernorm on the input data.
527
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
528
+ dtype=config.torch_dtype)
529
+
530
+ # Self attention.
531
+ self.self_attention = SelfAttention(config, layer_number, device=device)
532
+ self.hidden_dropout = config.hidden_dropout
533
+
534
+ # Layernorm on the attention output
535
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
536
+ dtype=config.torch_dtype)
537
+
538
+ # MLP
539
+ self.mlp = MLP(config, device=device)
540
+
541
+ def forward(
542
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
543
+ ):
544
+ # hidden_states: [s, b, h]
545
+
546
+ # Layer norm at the beginning of the transformer layer.
547
+ layernorm_output = self.input_layernorm(hidden_states)
548
+ # Self attention.
549
+ attention_output, kv_cache = self.self_attention(
550
+ layernorm_output,
551
+ attention_mask,
552
+ rotary_pos_emb,
553
+ kv_cache=kv_cache,
554
+ use_cache=use_cache
555
+ )
556
+
557
+ # Residual connection.
558
+ if self.apply_residual_connection_post_layernorm:
559
+ residual = layernorm_output
560
+ else:
561
+ residual = hidden_states
562
+
563
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
564
+ layernorm_input = residual + layernorm_input
565
+
566
+ # Layer norm post the self attention.
567
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
568
+
569
+ # MLP.
570
+ mlp_output = self.mlp(layernorm_output)
571
+
572
+ # Second residual connection.
573
+ if self.apply_residual_connection_post_layernorm:
574
+ residual = layernorm_output
575
+ else:
576
+ residual = layernorm_input
577
+
578
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
579
+ output = residual + output
580
+
581
+ return output, kv_cache
582
+
583
+
584
+ class GLMTransformer(torch.nn.Module):
585
+ """Transformer class."""
586
+
587
+ def __init__(self, config: ChatGLMConfig, device=None):
588
+ super(GLMTransformer, self).__init__()
589
+
590
+ self.fp32_residual_connection = config.fp32_residual_connection
591
+ self.post_layer_norm = config.post_layer_norm
592
+
593
+ # Number of layers.
594
+ self.num_layers = config.num_layers
595
+
596
+ # Transformer layers.
597
+ def build_layer(layer_number):
598
+ return GLMBlock(config, layer_number, device=device)
599
+
600
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
601
+
602
+ if self.post_layer_norm:
603
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
604
+ # Final layer norm before output.
605
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
606
+ dtype=config.torch_dtype)
607
+
608
+ self.gradient_checkpointing = False
609
+
610
+ def _get_layer(self, layer_number):
611
+ return self.layers[layer_number]
612
+
613
+ def forward(
614
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
615
+ use_cache: Optional[bool] = True,
616
+ output_hidden_states: Optional[bool] = False,
617
+ ):
618
+ if not kv_caches:
619
+ kv_caches = [None for _ in range(self.num_layers)]
620
+ presents = () if use_cache else None
621
+ if self.gradient_checkpointing and self.training:
622
+ if use_cache:
623
+ logger.warning_once(
624
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
625
+ )
626
+ use_cache = False
627
+
628
+ all_self_attentions = None
629
+ all_hidden_states = () if output_hidden_states else None
630
+ for index in range(self.num_layers):
631
+ if output_hidden_states:
632
+ all_hidden_states = all_hidden_states + (hidden_states,)
633
+
634
+ layer = self._get_layer(index)
635
+ if self.gradient_checkpointing and self.training:
636
+ layer_ret = torch.utils.checkpoint.checkpoint(
637
+ layer,
638
+ hidden_states,
639
+ attention_mask,
640
+ rotary_pos_emb,
641
+ kv_caches[index],
642
+ use_cache
643
+ )
644
+ else:
645
+ layer_ret = layer(
646
+ hidden_states,
647
+ attention_mask,
648
+ rotary_pos_emb,
649
+ kv_cache=kv_caches[index],
650
+ use_cache=use_cache
651
+ )
652
+ hidden_states, kv_cache = layer_ret
653
+ if use_cache:
654
+ presents = presents + (kv_cache,)
655
+
656
+ if output_hidden_states:
657
+ all_hidden_states = all_hidden_states + (hidden_states,)
658
+
659
+ # Final layer norm.
660
+ if self.post_layer_norm:
661
+ hidden_states = self.final_layernorm(hidden_states)
662
+
663
+ return hidden_states, presents, all_hidden_states, all_self_attentions
664
+
665
+
666
+ class ChatGLMPreTrainedModel(PreTrainedModel):
667
+ """
668
+ An abstract class to handle weights initialization and
669
+ a simple interface for downloading and loading pretrained models.
670
+ """
671
+
672
+ is_parallelizable = False
673
+ supports_gradient_checkpointing = True
674
+ config_class = ChatGLMConfig
675
+ base_model_prefix = "transformer"
676
+ _no_split_modules = ["GLMBlock"]
677
+
678
+ def _init_weights(self, module: nn.Module):
679
+ """Initialize the weights."""
680
+ return
681
+
682
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
683
+ batch_size, seq_length = input_ids.shape
684
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
685
+ full_attention_mask.tril_()
686
+ past_length = 0
687
+ if past_key_values:
688
+ past_length = past_key_values[0][0].shape[0]
689
+ if past_length:
690
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
691
+ device=input_ids.device), full_attention_mask), dim=-1)
692
+ if padding_mask is not None:
693
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
694
+ if not past_length and padding_mask is not None:
695
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
696
+ full_attention_mask = (full_attention_mask < 0.5).bool()
697
+ full_attention_mask.unsqueeze_(1)
698
+ return full_attention_mask
699
+
700
+ def get_position_ids(self, input_ids, device):
701
+ batch_size, seq_length = input_ids.shape
702
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
703
+ return position_ids
704
+
705
+ def _set_gradient_checkpointing(self, module, value=False):
706
+ if isinstance(module, GLMTransformer):
707
+ module.gradient_checkpointing = value
708
+
709
+
710
+ class Embedding(torch.nn.Module):
711
+ """Language model embeddings."""
712
+
713
+ def __init__(self, config: ChatGLMConfig, device=None):
714
+ super(Embedding, self).__init__()
715
+
716
+ self.hidden_size = config.hidden_size
717
+ # Word embeddings (parallel).
718
+ self.word_embeddings = nn.Embedding(
719
+ config.padded_vocab_size,
720
+ self.hidden_size,
721
+ dtype=config.torch_dtype,
722
+ device=device
723
+ )
724
+ self.fp32_residual_connection = config.fp32_residual_connection
725
+
726
+ def forward(self, input_ids):
727
+ # Embeddings.
728
+ words_embeddings = self.word_embeddings(input_ids)
729
+ embeddings = words_embeddings
730
+ # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
731
+ embeddings = embeddings.transpose(0, 1).contiguous()
732
+ # If the input flag for fp32 residual connection is set, convert for float.
733
+ if self.fp32_residual_connection:
734
+ embeddings = embeddings.float()
735
+ return embeddings
736
+
737
+
738
+ class ChatGLMModel(ChatGLMPreTrainedModel):
739
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
740
+ super().__init__(config)
741
+ if empty_init:
742
+ init_method = skip_init
743
+ else:
744
+ init_method = default_init
745
+ init_kwargs = {}
746
+ if device is not None:
747
+ init_kwargs["device"] = device
748
+ self.embedding = init_method(Embedding, config, **init_kwargs)
749
+ self.num_layers = config.num_layers
750
+ self.multi_query_group_num = config.multi_query_group_num
751
+ self.kv_channels = config.kv_channels
752
+
753
+ # Rotary positional embeddings
754
+ self.seq_length = config.seq_length
755
+ rotary_dim = (
756
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
757
+ )
758
+
759
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
760
+ dtype=config.torch_dtype)
761
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
762
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
763
+ dtype=config.torch_dtype, **init_kwargs)
764
+ self.pre_seq_len = config.pre_seq_len
765
+ self.prefix_projection = config.prefix_projection
766
+ if self.pre_seq_len is not None:
767
+ for param in self.parameters():
768
+ param.requires_grad = False
769
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
770
+ self.prefix_encoder = PrefixEncoder(config)
771
+ self.dropout = torch.nn.Dropout(0.1)
772
+
773
+ def get_input_embeddings(self):
774
+ return self.embedding.word_embeddings
775
+
776
+ def get_prompt(self, batch_size, device, dtype=torch.half):
777
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
778
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
779
+ past_key_values = past_key_values.view(
780
+ batch_size,
781
+ self.pre_seq_len,
782
+ self.num_layers * 2,
783
+ self.multi_query_group_num,
784
+ self.kv_channels
785
+ )
786
+ # seq_len, b, nh, hidden_size
787
+ past_key_values = self.dropout(past_key_values)
788
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
789
+ return past_key_values
790
+
791
+ def forward(
792
+ self,
793
+ input_ids,
794
+ position_ids: Optional[torch.Tensor] = None,
795
+ attention_mask: Optional[torch.BoolTensor] = None,
796
+ full_attention_mask: Optional[torch.BoolTensor] = None,
797
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
798
+ inputs_embeds: Optional[torch.Tensor] = None,
799
+ use_cache: Optional[bool] = None,
800
+ output_hidden_states: Optional[bool] = None,
801
+ return_dict: Optional[bool] = None,
802
+ ):
803
+ output_hidden_states = (
804
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
+ )
806
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
+
809
+ batch_size, seq_length = input_ids.shape
810
+
811
+ if inputs_embeds is None:
812
+ inputs_embeds = self.embedding(input_ids)
813
+
814
+ if self.pre_seq_len is not None:
815
+ if past_key_values is None:
816
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
817
+ dtype=inputs_embeds.dtype)
818
+ if attention_mask is not None:
819
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
820
+ attention_mask], dim=-1)
821
+
822
+ if full_attention_mask is None:
823
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
824
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
825
+
826
+ # Rotary positional embeddings
827
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
828
+ if position_ids is not None:
829
+ rotary_pos_emb = rotary_pos_emb[position_ids]
830
+ else:
831
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
832
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
833
+
834
+ # Run encoder.
835
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
836
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
837
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
838
+ )
839
+
840
+ if not return_dict:
841
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
842
+
843
+ return BaseModelOutputWithPast(
844
+ last_hidden_state=hidden_states,
845
+ past_key_values=presents,
846
+ hidden_states=all_hidden_states,
847
+ attentions=all_self_attentions,
848
+ )
849
+
850
+ def quantize(self, weight_bit_width: int):
851
+ from .quantization import quantize
852
+ quantize(self.encoder, weight_bit_width)
853
+ return self
854
+
855
+
856
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
857
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
858
+ super().__init__(config)
859
+
860
+ self.max_sequence_length = config.max_length
861
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
862
+ self.config = config
863
+ self.quantized = False
864
+
865
+ if self.config.quantization_bit:
866
+ self.quantize(self.config.quantization_bit, empty_init=True)
867
+
868
+ def _update_model_kwargs_for_generation(
869
+ self,
870
+ outputs: ModelOutput,
871
+ model_kwargs: Dict[str, Any],
872
+ is_encoder_decoder: bool = False,
873
+ standardize_cache_format: bool = False,
874
+ ) -> Dict[str, Any]:
875
+ # update past_key_values
876
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
877
+ outputs, standardize_cache_format=standardize_cache_format
878
+ )
879
+
880
+ # update attention mask
881
+ if "attention_mask" in model_kwargs:
882
+ attention_mask = model_kwargs["attention_mask"]
883
+ model_kwargs["attention_mask"] = torch.cat(
884
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
885
+ )
886
+
887
+ # update position ids
888
+ if "position_ids" in model_kwargs:
889
+ position_ids = model_kwargs["position_ids"]
890
+ new_position_id = position_ids[..., -1:].clone()
891
+ new_position_id += 1
892
+ model_kwargs["position_ids"] = torch.cat(
893
+ [position_ids, new_position_id], dim=-1
894
+ )
895
+
896
+ model_kwargs["is_first_forward"] = False
897
+ return model_kwargs
898
+
899
+ def prepare_inputs_for_generation(
900
+ self,
901
+ input_ids: torch.LongTensor,
902
+ past_key_values: Optional[torch.Tensor] = None,
903
+ attention_mask: Optional[torch.Tensor] = None,
904
+ position_ids: Optional[torch.Tensor] = None,
905
+ use_cache: Optional[bool] = None,
906
+ is_first_forward: bool = True,
907
+ **kwargs
908
+ ) -> dict:
909
+ # only last token for input_ids if past is not None
910
+ if position_ids is None:
911
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
912
+ if not is_first_forward:
913
+ if past_key_values is not None:
914
+ position_ids = position_ids[..., -1:]
915
+ input_ids = input_ids[:, -1:]
916
+ return {
917
+ "input_ids": input_ids,
918
+ "past_key_values": past_key_values,
919
+ "position_ids": position_ids,
920
+ "attention_mask": attention_mask,
921
+ "return_last_logit": True,
922
+ "use_cache": use_cache
923
+ }
924
+
925
+ def forward(
926
+ self,
927
+ input_ids: Optional[torch.Tensor] = None,
928
+ position_ids: Optional[torch.Tensor] = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
931
+ inputs_embeds: Optional[torch.Tensor] = None,
932
+ labels: Optional[torch.Tensor] = None,
933
+ use_cache: Optional[bool] = None,
934
+ output_attentions: Optional[bool] = None,
935
+ output_hidden_states: Optional[bool] = None,
936
+ return_dict: Optional[bool] = None,
937
+ return_last_logit: Optional[bool] = False,
938
+ ):
939
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
940
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
941
+
942
+ transformer_outputs = self.transformer(
943
+ input_ids=input_ids,
944
+ position_ids=position_ids,
945
+ attention_mask=attention_mask,
946
+ past_key_values=past_key_values,
947
+ inputs_embeds=inputs_embeds,
948
+ use_cache=use_cache,
949
+ output_hidden_states=output_hidden_states,
950
+ return_dict=return_dict,
951
+ )
952
+
953
+ hidden_states = transformer_outputs[0]
954
+ if return_last_logit:
955
+ hidden_states = hidden_states[-1:]
956
+ lm_logits = self.transformer.output_layer(hidden_states)
957
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
958
+
959
+ loss = None
960
+ if labels is not None:
961
+ lm_logits = lm_logits.to(torch.float32)
962
+
963
+ # Shift so that tokens < n predict n
964
+ shift_logits = lm_logits[..., :-1, :].contiguous()
965
+ shift_labels = labels[..., 1:].contiguous()
966
+ # Flatten the tokens
967
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
968
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
969
+
970
+ lm_logits = lm_logits.to(hidden_states.dtype)
971
+ loss = loss.to(hidden_states.dtype)
972
+
973
+ if not return_dict:
974
+ output = (lm_logits,) + transformer_outputs[1:]
975
+ return ((loss,) + output) if loss is not None else output
976
+
977
+ return CausalLMOutputWithPast(
978
+ loss=loss,
979
+ logits=lm_logits,
980
+ past_key_values=transformer_outputs.past_key_values,
981
+ hidden_states=transformer_outputs.hidden_states,
982
+ attentions=transformer_outputs.attentions,
983
+ )
984
+
985
+ @staticmethod
986
+ def _reorder_cache(
987
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
988
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
989
+ """
990
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
991
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
992
+ beam_idx at every generation step.
993
+
994
+ Output shares the same memory storage as `past`.
995
+ """
996
+ return tuple(
997
+ (
998
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
999
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1000
+ )
1001
+ for layer_past in past
1002
+ )
1003
+
1004
+ def process_response(self, output, history):
1005
+ content = ""
1006
+ history = deepcopy(history)
1007
+ for response in output.split("<|assistant|>"):
1008
+ metadata, content = response.split("\n", maxsplit=1)
1009
+ if not metadata.strip():
1010
+ content = content.strip()
1011
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1012
+ content = content.replace("[[训练时间]]", "2023年")
1013
+ else:
1014
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1015
+ if history[0]["role"] == "system" and "tools" in history[0]:
1016
+ content = "\n".join(content.split("\n")[1:-1])
1017
+ def tool_call(**kwargs):
1018
+ return kwargs
1019
+ parameters = eval(content)
1020
+ content = {"name": metadata.strip(), "parameters": parameters}
1021
+ else:
1022
+ content = {"name": metadata.strip(), "content": content}
1023
+ return content, history
1024
+
1025
+ @torch.inference_mode()
1026
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1027
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1028
+ **kwargs):
1029
+ if history is None:
1030
+ history = []
1031
+ if logits_processor is None:
1032
+ logits_processor = LogitsProcessorList()
1033
+ logits_processor.append(InvalidScoreLogitsProcessor())
1034
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1035
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1036
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1037
+ inputs = inputs.to(self.device)
1038
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1039
+ tokenizer.get_command("<|observation|>")]
1040
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1041
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1042
+ response = tokenizer.decode(outputs)
1043
+ history.append({"role": role, "content": query})
1044
+ response, history = self.process_response(response, history)
1045
+ return response, history
1046
+
1047
+ @torch.inference_mode()
1048
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1049
+ past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1050
+ logits_processor=None, return_past_key_values=False, **kwargs):
1051
+ if history is None:
1052
+ history = []
1053
+ if logits_processor is None:
1054
+ logits_processor = LogitsProcessorList()
1055
+ logits_processor.append(InvalidScoreLogitsProcessor())
1056
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1057
+ tokenizer.get_command("<|observation|>")]
1058
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1059
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1060
+ if past_key_values is None:
1061
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1062
+ else:
1063
+ inputs = tokenizer.build_chat_input(query, role=role)
1064
+ inputs = inputs.to(self.device)
1065
+ if past_key_values is not None:
1066
+ past_length = past_key_values[0][0].shape[0]
1067
+ if self.transformer.pre_seq_len is not None:
1068
+ past_length -= self.transformer.pre_seq_len
1069
+ inputs.position_ids += past_length
1070
+ attention_mask = inputs.attention_mask
1071
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1072
+ inputs['attention_mask'] = attention_mask
1073
+ history.append({"role": role, "content": query})
1074
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1075
+ eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1076
+ **gen_kwargs):
1077
+ if return_past_key_values:
1078
+ outputs, past_key_values = outputs
1079
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1080
+ response = tokenizer.decode(outputs)
1081
+ if response and response[-1] != "�":
1082
+ response, new_history = self.process_response(response, history)
1083
+ if return_past_key_values:
1084
+ yield response, new_history, past_key_values
1085
+ else:
1086
+ yield response, new_history
1087
+
1088
+ @torch.inference_mode()
1089
+ def stream_generate(
1090
+ self,
1091
+ input_ids,
1092
+ generation_config: Optional[GenerationConfig] = None,
1093
+ logits_processor: Optional[LogitsProcessorList] = None,
1094
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1095
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1096
+ return_past_key_values=False,
1097
+ **kwargs,
1098
+ ):
1099
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1100
+
1101
+ if generation_config is None:
1102
+ generation_config = self.generation_config
1103
+ generation_config = copy.deepcopy(generation_config)
1104
+ model_kwargs = generation_config.update(**kwargs)
1105
+ model_kwargs["use_cache"] = generation_config.use_cache
1106
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1107
+
1108
+ if isinstance(eos_token_id, int):
1109
+ eos_token_id = [eos_token_id]
1110
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1111
+
1112
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1113
+ if has_default_max_length and generation_config.max_new_tokens is None:
1114
+ warnings.warn(
1115
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1116
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1117
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1118
+ UserWarning,
1119
+ )
1120
+ elif generation_config.max_new_tokens is not None:
1121
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1122
+ if not has_default_max_length:
1123
+ logger.warn(
1124
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1125
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1126
+ "Please refer to the documentation for more information. "
1127
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1128
+ UserWarning,
1129
+ )
1130
+
1131
+ if input_ids_seq_length >= generation_config.max_length:
1132
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1133
+ logger.warning(
1134
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1135
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1136
+ " increasing `max_new_tokens`."
1137
+ )
1138
+
1139
+ # 2. Set generation parameters if not already defined
1140
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1141
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1142
+
1143
+ logits_processor = self._get_logits_processor(
1144
+ generation_config=generation_config,
1145
+ input_ids_seq_length=input_ids_seq_length,
1146
+ encoder_input_ids=input_ids,
1147
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1148
+ logits_processor=logits_processor,
1149
+ )
1150
+
1151
+ stopping_criteria = self._get_stopping_criteria(
1152
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1153
+ )
1154
+ logits_warper = self._get_logits_warper(generation_config)
1155
+
1156
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1157
+ scores = None
1158
+ while True:
1159
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1160
+ # forward pass to get next token
1161
+ outputs = self(
1162
+ **model_inputs,
1163
+ return_dict=True,
1164
+ output_attentions=False,
1165
+ output_hidden_states=False,
1166
+ )
1167
+
1168
+ next_token_logits = outputs.logits[:, -1, :]
1169
+
1170
+ # pre-process distribution
1171
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1172
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1173
+
1174
+ # sample
1175
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1176
+ if generation_config.do_sample:
1177
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1178
+ else:
1179
+ next_tokens = torch.argmax(probs, dim=-1)
1180
+ # update generated ids, model inputs, and length for next step
1181
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1182
+ model_kwargs = self._update_model_kwargs_for_generation(
1183
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1184
+ )
1185
+ unfinished_sequences = unfinished_sequences.mul(
1186
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1187
+ )
1188
+ if return_past_key_values:
1189
+ yield input_ids, outputs.past_key_values
1190
+ else:
1191
+ yield input_ids
1192
+ # stop when each sentence is finished, or if we exceed the maximum length
1193
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1194
+ break
1195
+
1196
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1197
+ if bits == 0:
1198
+ return
1199
+
1200
+ from .quantization import quantize
1201
+
1202
+ if self.quantized:
1203
+ logger.info("Already quantized.")
1204
+ return self
1205
+
1206
+ self.quantized = True
1207
+
1208
+ self.config.quantization_bit = bits
1209
+
1210
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1211
+ **kwargs)
1212
+ return self
1213
+
1214
+
1215
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1216
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1217
+ super().__init__(config)
1218
+
1219
+ self.num_labels = config.num_labels
1220
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1221
+
1222
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1223
+ if config.classifier_dropout is not None:
1224
+ self.dropout = nn.Dropout(config.classifier_dropout)
1225
+ else:
1226
+ self.dropout = None
1227
+ self.config = config
1228
+
1229
+ if self.config.quantization_bit:
1230
+ self.quantize(self.config.quantization_bit, empty_init=True)
1231
+
1232
+ def forward(
1233
+ self,
1234
+ input_ids: Optional[torch.LongTensor] = None,
1235
+ position_ids: Optional[torch.LongTensor] = None,
1236
+ attention_mask: Optional[torch.Tensor] = None,
1237
+ full_attention_mask: Optional[torch.Tensor] = None,
1238
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1239
+ inputs_embeds: Optional[torch.LongTensor] = None,
1240
+ labels: Optional[torch.LongTensor] = None,
1241
+ use_cache: Optional[bool] = None,
1242
+ output_hidden_states: Optional[bool] = None,
1243
+ return_dict: Optional[bool] = None,
1244
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1245
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1246
+
1247
+ transformer_outputs = self.transformer(
1248
+ input_ids=input_ids,
1249
+ position_ids=position_ids,
1250
+ attention_mask=attention_mask,
1251
+ full_attention_mask=full_attention_mask,
1252
+ past_key_values=past_key_values,
1253
+ inputs_embeds=inputs_embeds,
1254
+ use_cache=use_cache,
1255
+ output_hidden_states=output_hidden_states,
1256
+ return_dict=return_dict,
1257
+ )
1258
+
1259
+ hidden_states = transformer_outputs[0]
1260
+ pooled_hidden_states = hidden_states[-1]
1261
+ if self.dropout is not None:
1262
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1263
+ logits = self.classifier_head(pooled_hidden_states)
1264
+
1265
+ loss = None
1266
+ if labels is not None:
1267
+ if self.config.problem_type is None:
1268
+ if self.num_labels == 1:
1269
+ self.config.problem_type = "regression"
1270
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1271
+ self.config.problem_type = "single_label_classification"
1272
+ else:
1273
+ self.config.problem_type = "multi_label_classification"
1274
+
1275
+ if self.config.problem_type == "regression":
1276
+ loss_fct = MSELoss()
1277
+ if self.num_labels == 1:
1278
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1279
+ else:
1280
+ loss = loss_fct(logits.float(), labels)
1281
+ elif self.config.problem_type == "single_label_classification":
1282
+ loss_fct = CrossEntropyLoss()
1283
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1284
+ elif self.config.problem_type == "multi_label_classification":
1285
+ loss_fct = BCEWithLogitsLoss()
1286
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1287
+
1288
+ if not return_dict:
1289
+ output = (logits,) + transformer_outputs[1:]
1290
+ return ((loss,) + output) if loss is not None else output
1291
+
1292
+ return SequenceClassifierOutputWithPast(
1293
+ loss=loss,
1294
+ logits=logits,
1295
+ past_key_values=transformer_outputs.past_key_values,
1296
+ hidden_states=transformer_outputs.hidden_states,
1297
+ attentions=transformer_outputs.attentions,
1298
+ )
models/tokenization_chatglm.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ from typing import List, Optional, Union, Dict
5
+ from sentencepiece import SentencePieceProcessor
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers.utils import logging, PaddingStrategy
8
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
9
+
10
+
11
+ class SPTokenizer:
12
+ def __init__(self, model_path: str):
13
+ # reload tokenizer
14
+ assert os.path.isfile(model_path), model_path
15
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
16
+
17
+ # BOS / EOS token IDs
18
+ self.n_words: int = self.sp_model.vocab_size()
19
+ self.bos_id: int = self.sp_model.bos_id()
20
+ self.eos_id: int = self.sp_model.eos_id()
21
+ self.pad_id: int = self.sp_model.unk_id()
22
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
23
+
24
+ role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
25
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
26
+ self.special_tokens = {}
27
+ self.index_special_tokens = {}
28
+ for token in special_tokens:
29
+ self.special_tokens[token] = self.n_words
30
+ self.index_special_tokens[self.n_words] = token
31
+ self.n_words += 1
32
+ self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
33
+
34
+ def tokenize(self, s: str, encode_special_tokens=False):
35
+ if encode_special_tokens:
36
+ last_index = 0
37
+ t = []
38
+ for match in re.finditer(self.role_special_token_expression, s):
39
+ if last_index < match.start():
40
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
41
+ t.append(s[match.start():match.end()])
42
+ last_index = match.end()
43
+ if last_index < len(s):
44
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
45
+ return t
46
+ else:
47
+ return self.sp_model.EncodeAsPieces(s)
48
+
49
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
50
+ assert type(s) is str
51
+ t = self.sp_model.encode(s)
52
+ if bos:
53
+ t = [self.bos_id] + t
54
+ if eos:
55
+ t = t + [self.eos_id]
56
+ return t
57
+
58
+ def decode(self, t: List[int]) -> str:
59
+ text, buffer = "", []
60
+ for token in t:
61
+ if token in self.index_special_tokens:
62
+ if buffer:
63
+ text += self.sp_model.decode(buffer)
64
+ buffer = []
65
+ text += self.index_special_tokens[token]
66
+ else:
67
+ buffer.append(token)
68
+ if buffer:
69
+ text += self.sp_model.decode(buffer)
70
+ return text
71
+
72
+ def decode_tokens(self, tokens: List[str]) -> str:
73
+ text = self.sp_model.DecodePieces(tokens)
74
+ return text
75
+
76
+ def convert_token_to_id(self, token):
77
+ """ Converts a token (str) in an id using the vocab. """
78
+ if token in self.special_tokens:
79
+ return self.special_tokens[token]
80
+ return self.sp_model.PieceToId(token)
81
+
82
+ def convert_id_to_token(self, index):
83
+ """Converts an index (integer) in a token (str) using the vocab."""
84
+ if index in self.index_special_tokens:
85
+ return self.index_special_tokens[index]
86
+ if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
87
+ return ""
88
+ return self.sp_model.IdToPiece(index)
89
+
90
+
91
+ class ChatGLMTokenizer(PreTrainedTokenizer):
92
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
93
+
94
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
95
+
96
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
97
+ **kwargs):
98
+ self.name = "GLMTokenizer"
99
+
100
+ self.vocab_file = vocab_file
101
+ self.tokenizer = SPTokenizer(vocab_file)
102
+ self.special_tokens = {
103
+ "<bos>": self.tokenizer.bos_id,
104
+ "<eos>": self.tokenizer.eos_id,
105
+ "<pad>": self.tokenizer.pad_id
106
+ }
107
+ self.encode_special_tokens = encode_special_tokens
108
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
109
+ encode_special_tokens=encode_special_tokens,
110
+ **kwargs)
111
+
112
+ def get_command(self, token):
113
+ if token in self.special_tokens:
114
+ return self.special_tokens[token]
115
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
116
+ return self.tokenizer.special_tokens[token]
117
+
118
+ @property
119
+ def unk_token(self) -> str:
120
+ return "<unk>"
121
+
122
+ @property
123
+ def pad_token(self) -> str:
124
+ return "<unk>"
125
+
126
+ @property
127
+ def pad_token_id(self):
128
+ return self.get_command("<pad>")
129
+
130
+ @property
131
+ def eos_token(self) -> str:
132
+ return "</s>"
133
+
134
+ @property
135
+ def eos_token_id(self):
136
+ return self.get_command("<eos>")
137
+
138
+ @property
139
+ def vocab_size(self):
140
+ return self.tokenizer.n_words
141
+
142
+ def get_vocab(self):
143
+ """ Returns vocab as a dict """
144
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
145
+ vocab.update(self.added_tokens_encoder)
146
+ return vocab
147
+
148
+ def _tokenize(self, text, **kwargs):
149
+ return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
150
+
151
+ def _convert_token_to_id(self, token):
152
+ """ Converts a token (str) in an id using the vocab. """
153
+ return self.tokenizer.convert_token_to_id(token)
154
+
155
+ def _convert_id_to_token(self, index):
156
+ """Converts an index (integer) in a token (str) using the vocab."""
157
+ return self.tokenizer.convert_id_to_token(index)
158
+
159
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
160
+ return self.tokenizer.decode_tokens(tokens)
161
+
162
+ def save_vocabulary(self, save_directory, filename_prefix=None):
163
+ """
164
+ Save the vocabulary and special tokens file to a directory.
165
+
166
+ Args:
167
+ save_directory (`str`):
168
+ The directory in which to save the vocabulary.
169
+ filename_prefix (`str`, *optional*):
170
+ An optional prefix to add to the named of the saved files.
171
+
172
+ Returns:
173
+ `Tuple(str)`: Paths to the files saved.
174
+ """
175
+ if os.path.isdir(save_directory):
176
+ vocab_file = os.path.join(
177
+ save_directory, self.vocab_files_names["vocab_file"]
178
+ )
179
+ else:
180
+ vocab_file = save_directory
181
+
182
+ with open(self.vocab_file, 'rb') as fin:
183
+ proto_str = fin.read()
184
+
185
+ with open(vocab_file, "wb") as writer:
186
+ writer.write(proto_str)
187
+
188
+ return (vocab_file,)
189
+
190
+ def get_prefix_tokens(self):
191
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
192
+ return prefix_tokens
193
+
194
+ def build_single_message(self, role, metadata, message):
195
+ assert role in ["system", "user", "assistant", "observation"], role
196
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
197
+ message_tokens = self.tokenizer.encode(message)
198
+ tokens = role_tokens + message_tokens
199
+ return tokens
200
+
201
+ def build_chat_input(self, query, history=None, role="user"):
202
+ if history is None:
203
+ history = []
204
+ input_ids = []
205
+ for item in history:
206
+ content = item["content"]
207
+ if item["role"] == "system" and "tools" in item:
208
+ content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
209
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
210
+ input_ids.extend(self.build_single_message(role, "", query))
211
+ input_ids.extend([self.get_command("<|assistant|>")])
212
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
213
+
214
+ def build_inputs_with_special_tokens(
215
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
216
+ ) -> List[int]:
217
+ """
218
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
219
+ adding special tokens. A BERT sequence has the following format:
220
+
221
+ - single sequence: `[CLS] X [SEP]`
222
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
223
+
224
+ Args:
225
+ token_ids_0 (`List[int]`):
226
+ List of IDs to which the special tokens will be added.
227
+ token_ids_1 (`List[int]`, *optional*):
228
+ Optional second list of IDs for sequence pairs.
229
+
230
+ Returns:
231
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
232
+ """
233
+ prefix_tokens = self.get_prefix_tokens()
234
+ token_ids_0 = prefix_tokens + token_ids_0
235
+ if token_ids_1 is not None:
236
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
237
+ return token_ids_0
238
+
239
+ def _pad(
240
+ self,
241
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
242
+ max_length: Optional[int] = None,
243
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
244
+ pad_to_multiple_of: Optional[int] = None,
245
+ return_attention_mask: Optional[bool] = None,
246
+ ) -> dict:
247
+ """
248
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
249
+
250
+ Args:
251
+ encoded_inputs:
252
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
253
+ max_length: maximum length of the returned list and optionally padding length (see below).
254
+ Will truncate by taking into account the special tokens.
255
+ padding_strategy: PaddingStrategy to use for padding.
256
+
257
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
258
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
259
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
260
+ The tokenizer padding sides are defined in self.padding_side:
261
+
262
+ - 'left': pads on the left of the sequences
263
+ - 'right': pads on the right of the sequences
264
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
265
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
266
+ `>= 7.5` (Volta).
267
+ return_attention_mask:
268
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
269
+ """
270
+ # Load from model defaults
271
+ assert self.padding_side == "left"
272
+
273
+ required_input = encoded_inputs[self.model_input_names[0]]
274
+ seq_length = len(required_input)
275
+
276
+ if padding_strategy == PaddingStrategy.LONGEST:
277
+ max_length = len(required_input)
278
+
279
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
280
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
281
+
282
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
283
+
284
+ # Initialize attention mask if not present.
285
+ if "attention_mask" not in encoded_inputs:
286
+ encoded_inputs["attention_mask"] = [1] * seq_length
287
+
288
+ if "position_ids" not in encoded_inputs:
289
+ encoded_inputs["position_ids"] = list(range(seq_length))
290
+
291
+ if needs_to_be_padded:
292
+ difference = max_length - len(required_input)
293
+
294
+ if "attention_mask" in encoded_inputs:
295
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
296
+ if "position_ids" in encoded_inputs:
297
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
298
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
299
+
300
+ return encoded_inputs
models/unet_2d_condition.py ADDED
@@ -0,0 +1,1318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnAddedKVProcessor,
32
+ AttnProcessor,
33
+ )
34
+ from diffusers.models.embeddings import (
35
+ GaussianFourierProjection,
36
+ GLIGENTextBoundingboxProjection,
37
+ ImageHintTimeEmbedding,
38
+ ImageProjection,
39
+ ImageTimeEmbedding,
40
+ TextImageProjection,
41
+ TextImageTimeEmbedding,
42
+ TextTimeEmbedding,
43
+ TimestepEmbedding,
44
+ Timesteps,
45
+ )
46
+ from diffusers.models.modeling_utils import ModelMixin
47
+
48
+ try:
49
+ from diffusers.models.unet_2d_blocks import (
50
+ get_down_block,
51
+ get_mid_block,
52
+ get_up_block,
53
+ )
54
+ except:
55
+ from diffusers.models.unets.unet_2d_blocks import (
56
+ get_down_block,
57
+ get_mid_block,
58
+ get_up_block,
59
+ )
60
+
61
+
62
+
63
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
+
65
+
66
+ @dataclass
67
+ class UNet2DConditionOutput(BaseOutput):
68
+ """
69
+ The output of [`UNet2DConditionModel`].
70
+
71
+ Args:
72
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
73
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
74
+ """
75
+
76
+ sample: torch.Tensor = None
77
+
78
+
79
+ class UNet2DConditionModel(
80
+ ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
81
+ ):
82
+ r"""
83
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
84
+ shaped output.
85
+
86
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
87
+ for all models (such as downloading or saving).
88
+
89
+ Parameters:
90
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
91
+ Height and width of input/output sample.
92
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
93
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
94
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
95
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
96
+ Whether to flip the sin to cos in the time embedding.
97
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
98
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
99
+ The tuple of downsample blocks to use.
100
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
101
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
102
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
103
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
104
+ The tuple of upsample blocks to use.
105
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
106
+ Whether to include self-attention in the basic transformer blocks, see
107
+ [`~models.attention.BasicTransformerBlock`].
108
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
109
+ The tuple of output channels for each block.
110
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
111
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
112
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
113
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
114
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
115
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
116
+ If `None`, normalization and activation layers is skipped in post-processing.
117
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
118
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
119
+ The dimension of the cross attention features.
120
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
121
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
122
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
123
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
124
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
125
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
126
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
127
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
128
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
129
+ encoder_hid_dim (`int`, *optional*, defaults to None):
130
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
131
+ dimension to `cross_attention_dim`.
132
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
133
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
134
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
135
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
136
+ num_attention_heads (`int`, *optional*):
137
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
138
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
139
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
140
+ class_embed_type (`str`, *optional*, defaults to `None`):
141
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
142
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
143
+ addition_embed_type (`str`, *optional*, defaults to `None`):
144
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
145
+ "text". "text" will use the `TextTimeEmbedding` layer.
146
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
147
+ Dimension for the timestep embeddings.
148
+ num_class_embeds (`int`, *optional*, defaults to `None`):
149
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
150
+ class conditioning with `class_embed_type` equal to `None`.
151
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
152
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
153
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
154
+ An optional override for the dimension of the projected time embedding.
155
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
156
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
157
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
158
+ timestep_post_act (`str`, *optional*, defaults to `None`):
159
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
160
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
161
+ The dimension of `cond_proj` layer in the timestep embedding.
162
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
163
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
164
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
165
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
166
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
167
+ embeddings with the class embeddings.
168
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
169
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
170
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
171
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
172
+ otherwise.
173
+ """
174
+
175
+ _supports_gradient_checkpointing = True
176
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ sample_size: Optional[int] = None,
182
+ in_channels: int = 4,
183
+ out_channels: int = 4,
184
+ center_input_sample: bool = False,
185
+ flip_sin_to_cos: bool = True,
186
+ freq_shift: int = 0,
187
+ down_block_types: Tuple[str] = (
188
+ "CrossAttnDownBlock2D",
189
+ "CrossAttnDownBlock2D",
190
+ "CrossAttnDownBlock2D",
191
+ "DownBlock2D",
192
+ ),
193
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
194
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
195
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
196
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
197
+ layers_per_block: Union[int, Tuple[int]] = 2,
198
+ downsample_padding: int = 1,
199
+ mid_block_scale_factor: float = 1,
200
+ dropout: float = 0.0,
201
+ act_fn: str = "silu",
202
+ norm_num_groups: Optional[int] = 32,
203
+ norm_eps: float = 1e-5,
204
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
205
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
206
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
207
+ encoder_hid_dim: Optional[int] = None,
208
+ encoder_hid_dim_type: Optional[str] = None,
209
+ attention_head_dim: Union[int, Tuple[int]] = 8,
210
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
211
+ dual_cross_attention: bool = False,
212
+ use_linear_projection: bool = False,
213
+ class_embed_type: Optional[str] = None,
214
+ addition_embed_type: Optional[str] = None,
215
+ addition_time_embed_dim: Optional[int] = None,
216
+ num_class_embeds: Optional[int] = None,
217
+ upcast_attention: bool = False,
218
+ resnet_time_scale_shift: str = "default",
219
+ resnet_skip_time_act: bool = False,
220
+ resnet_out_scale_factor: float = 1.0,
221
+ time_embedding_type: str = "positional",
222
+ time_embedding_dim: Optional[int] = None,
223
+ time_embedding_act_fn: Optional[str] = None,
224
+ timestep_post_act: Optional[str] = None,
225
+ time_cond_proj_dim: Optional[int] = None,
226
+ conv_in_kernel: int = 3,
227
+ conv_out_kernel: int = 3,
228
+ projection_class_embeddings_input_dim: Optional[int] = None,
229
+ attention_type: str = "default",
230
+ class_embeddings_concat: bool = False,
231
+ mid_block_only_cross_attention: Optional[bool] = None,
232
+ cross_attention_norm: Optional[str] = None,
233
+ addition_embed_type_num_heads: int = 64,
234
+ ):
235
+ super().__init__()
236
+
237
+ self.sample_size = sample_size
238
+
239
+ if num_attention_heads is not None:
240
+ raise ValueError(
241
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
242
+ )
243
+
244
+ # If `num_attention_heads` is not defined (which is the case for most models)
245
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
246
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
247
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
248
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
249
+ # which is why we correct for the naming here.
250
+ num_attention_heads = num_attention_heads or attention_head_dim
251
+
252
+ # Check inputs
253
+ self._check_config(
254
+ down_block_types=down_block_types,
255
+ up_block_types=up_block_types,
256
+ only_cross_attention=only_cross_attention,
257
+ block_out_channels=block_out_channels,
258
+ layers_per_block=layers_per_block,
259
+ cross_attention_dim=cross_attention_dim,
260
+ transformer_layers_per_block=transformer_layers_per_block,
261
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
262
+ attention_head_dim=attention_head_dim,
263
+ num_attention_heads=num_attention_heads,
264
+ )
265
+
266
+ # input
267
+ conv_in_padding = (conv_in_kernel - 1) // 2
268
+ self.conv_in = nn.Conv2d(
269
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
270
+ )
271
+
272
+ # time
273
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
274
+ time_embedding_type,
275
+ block_out_channels=block_out_channels,
276
+ flip_sin_to_cos=flip_sin_to_cos,
277
+ freq_shift=freq_shift,
278
+ time_embedding_dim=time_embedding_dim,
279
+ )
280
+
281
+ self.time_embedding = TimestepEmbedding(
282
+ timestep_input_dim,
283
+ time_embed_dim,
284
+ act_fn=act_fn,
285
+ post_act_fn=timestep_post_act,
286
+ cond_proj_dim=time_cond_proj_dim,
287
+ )
288
+
289
+ self._set_encoder_hid_proj(
290
+ encoder_hid_dim_type,
291
+ cross_attention_dim=cross_attention_dim,
292
+ encoder_hid_dim=encoder_hid_dim,
293
+ )
294
+
295
+ # class embedding
296
+ self._set_class_embedding(
297
+ class_embed_type,
298
+ act_fn=act_fn,
299
+ num_class_embeds=num_class_embeds,
300
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
+ time_embed_dim=time_embed_dim,
302
+ timestep_input_dim=timestep_input_dim,
303
+ )
304
+
305
+ self._set_add_embedding(
306
+ addition_embed_type,
307
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
308
+ addition_time_embed_dim=addition_time_embed_dim,
309
+ cross_attention_dim=cross_attention_dim,
310
+ encoder_hid_dim=encoder_hid_dim,
311
+ flip_sin_to_cos=flip_sin_to_cos,
312
+ freq_shift=freq_shift,
313
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
314
+ time_embed_dim=time_embed_dim,
315
+ )
316
+
317
+ if time_embedding_act_fn is None:
318
+ self.time_embed_act = None
319
+ else:
320
+ self.time_embed_act = get_activation(time_embedding_act_fn)
321
+
322
+ self.down_blocks = nn.ModuleList([])
323
+ self.up_blocks = nn.ModuleList([])
324
+
325
+ if isinstance(only_cross_attention, bool):
326
+ if mid_block_only_cross_attention is None:
327
+ mid_block_only_cross_attention = only_cross_attention
328
+
329
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
330
+
331
+ if mid_block_only_cross_attention is None:
332
+ mid_block_only_cross_attention = False
333
+
334
+ if isinstance(num_attention_heads, int):
335
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
336
+
337
+ if isinstance(attention_head_dim, int):
338
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
339
+
340
+ if isinstance(cross_attention_dim, int):
341
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
342
+
343
+ if isinstance(layers_per_block, int):
344
+ layers_per_block = [layers_per_block] * len(down_block_types)
345
+
346
+ if isinstance(transformer_layers_per_block, int):
347
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
348
+
349
+ if class_embeddings_concat:
350
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
351
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
352
+ # regular time embeddings
353
+ blocks_time_embed_dim = time_embed_dim * 2
354
+ else:
355
+ blocks_time_embed_dim = time_embed_dim
356
+
357
+ # down
358
+ output_channel = block_out_channels[0]
359
+ for i, down_block_type in enumerate(down_block_types):
360
+ input_channel = output_channel
361
+ output_channel = block_out_channels[i]
362
+ is_final_block = i == len(block_out_channels) - 1
363
+
364
+ down_block = get_down_block(
365
+ down_block_type,
366
+ num_layers=layers_per_block[i],
367
+ transformer_layers_per_block=transformer_layers_per_block[i],
368
+ in_channels=input_channel,
369
+ out_channels=output_channel,
370
+ temb_channels=blocks_time_embed_dim,
371
+ add_downsample=not is_final_block,
372
+ resnet_eps=norm_eps,
373
+ resnet_act_fn=act_fn,
374
+ resnet_groups=norm_num_groups,
375
+ cross_attention_dim=cross_attention_dim[i],
376
+ num_attention_heads=num_attention_heads[i],
377
+ downsample_padding=downsample_padding,
378
+ dual_cross_attention=dual_cross_attention,
379
+ use_linear_projection=use_linear_projection,
380
+ only_cross_attention=only_cross_attention[i],
381
+ upcast_attention=upcast_attention,
382
+ resnet_time_scale_shift=resnet_time_scale_shift,
383
+ attention_type=attention_type,
384
+ resnet_skip_time_act=resnet_skip_time_act,
385
+ resnet_out_scale_factor=resnet_out_scale_factor,
386
+ cross_attention_norm=cross_attention_norm,
387
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
388
+ dropout=dropout,
389
+ )
390
+ self.down_blocks.append(down_block)
391
+
392
+ # mid
393
+ self.mid_block = get_mid_block(
394
+ mid_block_type,
395
+ temb_channels=blocks_time_embed_dim,
396
+ in_channels=block_out_channels[-1],
397
+ resnet_eps=norm_eps,
398
+ resnet_act_fn=act_fn,
399
+ resnet_groups=norm_num_groups,
400
+ output_scale_factor=mid_block_scale_factor,
401
+ transformer_layers_per_block=transformer_layers_per_block[-1],
402
+ num_attention_heads=num_attention_heads[-1],
403
+ cross_attention_dim=cross_attention_dim[-1],
404
+ dual_cross_attention=dual_cross_attention,
405
+ use_linear_projection=use_linear_projection,
406
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
407
+ upcast_attention=upcast_attention,
408
+ resnet_time_scale_shift=resnet_time_scale_shift,
409
+ attention_type=attention_type,
410
+ resnet_skip_time_act=resnet_skip_time_act,
411
+ cross_attention_norm=cross_attention_norm,
412
+ attention_head_dim=attention_head_dim[-1],
413
+ dropout=dropout,
414
+ )
415
+
416
+ # count how many layers upsample the images
417
+ self.num_upsamplers = 0
418
+
419
+ # up
420
+ reversed_block_out_channels = list(reversed(block_out_channels))
421
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
422
+ reversed_layers_per_block = list(reversed(layers_per_block))
423
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
424
+ reversed_transformer_layers_per_block = (
425
+ list(reversed(transformer_layers_per_block))
426
+ if reverse_transformer_layers_per_block is None
427
+ else reverse_transformer_layers_per_block
428
+ )
429
+ only_cross_attention = list(reversed(only_cross_attention))
430
+
431
+ output_channel = reversed_block_out_channels[0]
432
+ for i, up_block_type in enumerate(up_block_types):
433
+ is_final_block = i == len(block_out_channels) - 1
434
+
435
+ prev_output_channel = output_channel
436
+ output_channel = reversed_block_out_channels[i]
437
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
438
+
439
+ # add upsample block for all BUT final layer
440
+ if not is_final_block:
441
+ add_upsample = True
442
+ self.num_upsamplers += 1
443
+ else:
444
+ add_upsample = False
445
+
446
+ up_block = get_up_block(
447
+ up_block_type,
448
+ num_layers=reversed_layers_per_block[i] + 1,
449
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
450
+ in_channels=input_channel,
451
+ out_channels=output_channel,
452
+ prev_output_channel=prev_output_channel,
453
+ temb_channels=blocks_time_embed_dim,
454
+ add_upsample=add_upsample,
455
+ resnet_eps=norm_eps,
456
+ resnet_act_fn=act_fn,
457
+ resolution_idx=i,
458
+ resnet_groups=norm_num_groups,
459
+ cross_attention_dim=reversed_cross_attention_dim[i],
460
+ num_attention_heads=reversed_num_attention_heads[i],
461
+ dual_cross_attention=dual_cross_attention,
462
+ use_linear_projection=use_linear_projection,
463
+ only_cross_attention=only_cross_attention[i],
464
+ upcast_attention=upcast_attention,
465
+ resnet_time_scale_shift=resnet_time_scale_shift,
466
+ attention_type=attention_type,
467
+ resnet_skip_time_act=resnet_skip_time_act,
468
+ resnet_out_scale_factor=resnet_out_scale_factor,
469
+ cross_attention_norm=cross_attention_norm,
470
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
471
+ dropout=dropout,
472
+ )
473
+ self.up_blocks.append(up_block)
474
+ prev_output_channel = output_channel
475
+
476
+ # out
477
+ if norm_num_groups is not None:
478
+ self.conv_norm_out = nn.GroupNorm(
479
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
480
+ )
481
+
482
+ self.conv_act = get_activation(act_fn)
483
+
484
+ else:
485
+ self.conv_norm_out = None
486
+ self.conv_act = None
487
+
488
+ conv_out_padding = (conv_out_kernel - 1) // 2
489
+ self.conv_out = nn.Conv2d(
490
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
491
+ )
492
+
493
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
494
+
495
+ def _check_config(
496
+ self,
497
+ down_block_types: Tuple[str],
498
+ up_block_types: Tuple[str],
499
+ only_cross_attention: Union[bool, Tuple[bool]],
500
+ block_out_channels: Tuple[int],
501
+ layers_per_block: Union[int, Tuple[int]],
502
+ cross_attention_dim: Union[int, Tuple[int]],
503
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
504
+ reverse_transformer_layers_per_block: bool,
505
+ attention_head_dim: int,
506
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
507
+ ):
508
+ if len(down_block_types) != len(up_block_types):
509
+ raise ValueError(
510
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
511
+ )
512
+
513
+ if len(block_out_channels) != len(down_block_types):
514
+ raise ValueError(
515
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
516
+ )
517
+
518
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
519
+ raise ValueError(
520
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
521
+ )
522
+
523
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
524
+ raise ValueError(
525
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
526
+ )
527
+
528
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
529
+ raise ValueError(
530
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
531
+ )
532
+
533
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
534
+ raise ValueError(
535
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
536
+ )
537
+
538
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
539
+ raise ValueError(
540
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
541
+ )
542
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
543
+ for layer_number_per_block in transformer_layers_per_block:
544
+ if isinstance(layer_number_per_block, list):
545
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
546
+
547
+ def _set_time_proj(
548
+ self,
549
+ time_embedding_type: str,
550
+ block_out_channels: int,
551
+ flip_sin_to_cos: bool,
552
+ freq_shift: float,
553
+ time_embedding_dim: int,
554
+ ) -> Tuple[int, int]:
555
+ if time_embedding_type == "fourier":
556
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
557
+ if time_embed_dim % 2 != 0:
558
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
559
+ self.time_proj = GaussianFourierProjection(
560
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
561
+ )
562
+ timestep_input_dim = time_embed_dim
563
+ elif time_embedding_type == "positional":
564
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
565
+
566
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
567
+ timestep_input_dim = block_out_channels[0]
568
+ else:
569
+ raise ValueError(
570
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
571
+ )
572
+
573
+ return time_embed_dim, timestep_input_dim
574
+
575
+ def _set_encoder_hid_proj(
576
+ self,
577
+ encoder_hid_dim_type: Optional[str],
578
+ cross_attention_dim: Union[int, Tuple[int]],
579
+ encoder_hid_dim: Optional[int],
580
+ ):
581
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
582
+ encoder_hid_dim_type = "text_proj"
583
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
584
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
585
+
586
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
587
+ raise ValueError(
588
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
589
+ )
590
+
591
+ if encoder_hid_dim_type == "text_proj":
592
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
593
+ elif encoder_hid_dim_type == "text_image_proj":
594
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
595
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
596
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
597
+ self.encoder_hid_proj = TextImageProjection(
598
+ text_embed_dim=encoder_hid_dim,
599
+ image_embed_dim=cross_attention_dim,
600
+ cross_attention_dim=cross_attention_dim,
601
+ )
602
+ elif encoder_hid_dim_type == "image_proj":
603
+ # Kandinsky 2.2
604
+ self.encoder_hid_proj = ImageProjection(
605
+ image_embed_dim=encoder_hid_dim,
606
+ cross_attention_dim=cross_attention_dim,
607
+ )
608
+ elif encoder_hid_dim_type is not None:
609
+ raise ValueError(
610
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
611
+ )
612
+ else:
613
+ self.encoder_hid_proj = None
614
+
615
+ def _set_class_embedding(
616
+ self,
617
+ class_embed_type: Optional[str],
618
+ act_fn: str,
619
+ num_class_embeds: Optional[int],
620
+ projection_class_embeddings_input_dim: Optional[int],
621
+ time_embed_dim: int,
622
+ timestep_input_dim: int,
623
+ ):
624
+ if class_embed_type is None and num_class_embeds is not None:
625
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
626
+ elif class_embed_type == "timestep":
627
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
628
+ elif class_embed_type == "identity":
629
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
630
+ elif class_embed_type == "projection":
631
+ if projection_class_embeddings_input_dim is None:
632
+ raise ValueError(
633
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
634
+ )
635
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
636
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
637
+ # 2. it projects from an arbitrary input dimension.
638
+ #
639
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
640
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
641
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
642
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
643
+ elif class_embed_type == "simple_projection":
644
+ if projection_class_embeddings_input_dim is None:
645
+ raise ValueError(
646
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
647
+ )
648
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
649
+ else:
650
+ self.class_embedding = None
651
+
652
+ def _set_add_embedding(
653
+ self,
654
+ addition_embed_type: str,
655
+ addition_embed_type_num_heads: int,
656
+ addition_time_embed_dim: Optional[int],
657
+ flip_sin_to_cos: bool,
658
+ freq_shift: float,
659
+ cross_attention_dim: Optional[int],
660
+ encoder_hid_dim: Optional[int],
661
+ projection_class_embeddings_input_dim: Optional[int],
662
+ time_embed_dim: int,
663
+ ):
664
+ if addition_embed_type == "text":
665
+ if encoder_hid_dim is not None:
666
+ text_time_embedding_from_dim = encoder_hid_dim
667
+ else:
668
+ text_time_embedding_from_dim = cross_attention_dim
669
+
670
+ self.add_embedding = TextTimeEmbedding(
671
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
672
+ )
673
+ elif addition_embed_type == "text_image":
674
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
675
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
676
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
677
+ self.add_embedding = TextImageTimeEmbedding(
678
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
679
+ )
680
+ elif addition_embed_type == "text_time":
681
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
682
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
683
+ elif addition_embed_type == "image":
684
+ # Kandinsky 2.2
685
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
686
+ elif addition_embed_type == "image_hint":
687
+ # Kandinsky 2.2 ControlNet
688
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
689
+ elif addition_embed_type is not None:
690
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
691
+
692
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
693
+ if attention_type in ["gated", "gated-text-image"]:
694
+ positive_len = 768
695
+ if isinstance(cross_attention_dim, int):
696
+ positive_len = cross_attention_dim
697
+ elif isinstance(cross_attention_dim, (list, tuple)):
698
+ positive_len = cross_attention_dim[0]
699
+
700
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
701
+ self.position_net = GLIGENTextBoundingboxProjection(
702
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
703
+ )
704
+
705
+ @property
706
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
707
+ r"""
708
+ Returns:
709
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
710
+ indexed by its weight name.
711
+ """
712
+ # set recursively
713
+ processors = {}
714
+
715
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
716
+ if hasattr(module, "get_processor"):
717
+ processors[f"{name}.processor"] = module.get_processor()
718
+
719
+ for sub_name, child in module.named_children():
720
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
721
+
722
+ return processors
723
+
724
+ for name, module in self.named_children():
725
+ fn_recursive_add_processors(name, module, processors)
726
+
727
+ return processors
728
+
729
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
730
+ r"""
731
+ Sets the attention processor to use to compute attention.
732
+
733
+ Parameters:
734
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
735
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
736
+ for **all** `Attention` layers.
737
+
738
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
739
+ processor. This is strongly recommended when setting trainable attention processors.
740
+
741
+ """
742
+ count = len(self.attn_processors.keys())
743
+
744
+ if isinstance(processor, dict) and len(processor) != count:
745
+ raise ValueError(
746
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
747
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
748
+ )
749
+
750
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
751
+ if hasattr(module, "set_processor"):
752
+ if not isinstance(processor, dict):
753
+ module.set_processor(processor)
754
+ else:
755
+ module.set_processor(processor.pop(f"{name}.processor"))
756
+
757
+ for sub_name, child in module.named_children():
758
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
759
+
760
+ for name, module in self.named_children():
761
+ fn_recursive_attn_processor(name, module, processor)
762
+
763
+ def set_default_attn_processor(self):
764
+ """
765
+ Disables custom attention processors and sets the default attention implementation.
766
+ """
767
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
768
+ processor = AttnAddedKVProcessor()
769
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
770
+ processor = AttnProcessor()
771
+ else:
772
+ raise ValueError(
773
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
774
+ )
775
+
776
+ self.set_attn_processor(processor)
777
+
778
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
779
+ r"""
780
+ Enable sliced attention computation.
781
+
782
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
783
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
784
+
785
+ Args:
786
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
787
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
788
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
789
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
790
+ must be a multiple of `slice_size`.
791
+ """
792
+ sliceable_head_dims = []
793
+
794
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
795
+ if hasattr(module, "set_attention_slice"):
796
+ sliceable_head_dims.append(module.sliceable_head_dim)
797
+
798
+ for child in module.children():
799
+ fn_recursive_retrieve_sliceable_dims(child)
800
+
801
+ # retrieve number of attention layers
802
+ for module in self.children():
803
+ fn_recursive_retrieve_sliceable_dims(module)
804
+
805
+ num_sliceable_layers = len(sliceable_head_dims)
806
+
807
+ if slice_size == "auto":
808
+ # half the attention head size is usually a good trade-off between
809
+ # speed and memory
810
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
811
+ elif slice_size == "max":
812
+ # make smallest slice possible
813
+ slice_size = num_sliceable_layers * [1]
814
+
815
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
816
+
817
+ if len(slice_size) != len(sliceable_head_dims):
818
+ raise ValueError(
819
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
820
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
821
+ )
822
+
823
+ for i in range(len(slice_size)):
824
+ size = slice_size[i]
825
+ dim = sliceable_head_dims[i]
826
+ if size is not None and size > dim:
827
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
828
+
829
+ # Recursively walk through all the children.
830
+ # Any children which exposes the set_attention_slice method
831
+ # gets the message
832
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
833
+ if hasattr(module, "set_attention_slice"):
834
+ module.set_attention_slice(slice_size.pop())
835
+
836
+ for child in module.children():
837
+ fn_recursive_set_attention_slice(child, slice_size)
838
+
839
+ reversed_slice_size = list(reversed(slice_size))
840
+ for module in self.children():
841
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
842
+
843
+ def _set_gradient_checkpointing(self, module, value=False):
844
+ if hasattr(module, "gradient_checkpointing"):
845
+ module.gradient_checkpointing = value
846
+
847
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
848
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
849
+
850
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
851
+
852
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
853
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
854
+
855
+ Args:
856
+ s1 (`float`):
857
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
858
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
859
+ s2 (`float`):
860
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
861
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
862
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
863
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
864
+ """
865
+ for i, upsample_block in enumerate(self.up_blocks):
866
+ setattr(upsample_block, "s1", s1)
867
+ setattr(upsample_block, "s2", s2)
868
+ setattr(upsample_block, "b1", b1)
869
+ setattr(upsample_block, "b2", b2)
870
+
871
+ def disable_freeu(self):
872
+ """Disables the FreeU mechanism."""
873
+ freeu_keys = {"s1", "s2", "b1", "b2"}
874
+ for i, upsample_block in enumerate(self.up_blocks):
875
+ for k in freeu_keys:
876
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
877
+ setattr(upsample_block, k, None)
878
+
879
+ def fuse_qkv_projections(self):
880
+ """
881
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
882
+ are fused. For cross-attention modules, key and value projection matrices are fused.
883
+
884
+ <Tip warning={true}>
885
+
886
+ This API is 🧪 experimental.
887
+
888
+ </Tip>
889
+ """
890
+ self.original_attn_processors = None
891
+
892
+ for _, attn_processor in self.attn_processors.items():
893
+ if "Added" in str(attn_processor.__class__.__name__):
894
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
895
+
896
+ self.original_attn_processors = self.attn_processors
897
+
898
+ for module in self.modules():
899
+ if isinstance(module, Attention):
900
+ module.fuse_projections(fuse=True)
901
+
902
+ def unfuse_qkv_projections(self):
903
+ """Disables the fused QKV projection if enabled.
904
+
905
+ <Tip warning={true}>
906
+
907
+ This API is 🧪 experimental.
908
+
909
+ </Tip>
910
+
911
+ """
912
+ if self.original_attn_processors is not None:
913
+ self.set_attn_processor(self.original_attn_processors)
914
+
915
+ def get_time_embed(
916
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
917
+ ) -> Optional[torch.Tensor]:
918
+ timesteps = timestep
919
+ if not torch.is_tensor(timesteps):
920
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
921
+ # This would be a good case for the `match` statement (Python 3.10+)
922
+ is_mps = sample.device.type == "mps"
923
+ if isinstance(timestep, float):
924
+ dtype = torch.float32 if is_mps else torch.float64
925
+ else:
926
+ dtype = torch.int32 if is_mps else torch.int64
927
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
928
+ elif len(timesteps.shape) == 0:
929
+ timesteps = timesteps[None].to(sample.device)
930
+
931
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
932
+ timesteps = timesteps.expand(sample.shape[0])
933
+
934
+ t_emb = self.time_proj(timesteps)
935
+ # `Timesteps` does not contain any weights and will always return f32 tensors
936
+ # but time_embedding might actually be running in fp16. so we need to cast here.
937
+ # there might be better ways to encapsulate this.
938
+ t_emb = t_emb.to(dtype=sample.dtype)
939
+ return t_emb
940
+
941
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
942
+ class_emb = None
943
+ if self.class_embedding is not None:
944
+ if class_labels is None:
945
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
946
+
947
+ if self.config.class_embed_type == "timestep":
948
+ class_labels = self.time_proj(class_labels)
949
+
950
+ # `Timesteps` does not contain any weights and will always return f32 tensors
951
+ # there might be better ways to encapsulate this.
952
+ class_labels = class_labels.to(dtype=sample.dtype)
953
+
954
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
955
+ return class_emb
956
+
957
+ def get_aug_embed(
958
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
959
+ ) -> Optional[torch.Tensor]:
960
+ aug_emb = None
961
+ if self.config.addition_embed_type == "text":
962
+ aug_emb = self.add_embedding(encoder_hidden_states)
963
+ elif self.config.addition_embed_type == "text_image":
964
+ # Kandinsky 2.1 - style
965
+ if "image_embeds" not in added_cond_kwargs:
966
+ raise ValueError(
967
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
968
+ )
969
+
970
+ image_embs = added_cond_kwargs.get("image_embeds")
971
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
972
+ aug_emb = self.add_embedding(text_embs, image_embs)
973
+ elif self.config.addition_embed_type == "text_time":
974
+ # SDXL - style
975
+ if "text_embeds" not in added_cond_kwargs:
976
+ raise ValueError(
977
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
978
+ )
979
+ text_embeds = added_cond_kwargs.get("text_embeds")
980
+ if "time_ids" not in added_cond_kwargs:
981
+ raise ValueError(
982
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
983
+ )
984
+ time_ids = added_cond_kwargs.get("time_ids")
985
+ time_embeds = self.add_time_proj(time_ids.flatten())
986
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
987
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
988
+ add_embeds = add_embeds.to(emb.dtype)
989
+ aug_emb = self.add_embedding(add_embeds)
990
+ elif self.config.addition_embed_type == "image":
991
+ # Kandinsky 2.2 - style
992
+ if "image_embeds" not in added_cond_kwargs:
993
+ raise ValueError(
994
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
995
+ )
996
+ image_embs = added_cond_kwargs.get("image_embeds")
997
+ aug_emb = self.add_embedding(image_embs)
998
+ elif self.config.addition_embed_type == "image_hint":
999
+ # Kandinsky 2.2 - style
1000
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1001
+ raise ValueError(
1002
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1003
+ )
1004
+ image_embs = added_cond_kwargs.get("image_embeds")
1005
+ hint = added_cond_kwargs.get("hint")
1006
+ aug_emb = self.add_embedding(image_embs, hint)
1007
+ return aug_emb
1008
+
1009
+ def process_encoder_hidden_states(
1010
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1011
+ ) -> torch.Tensor:
1012
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1013
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1014
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1015
+ # Kandinsky 2.1 - style
1016
+ if "image_embeds" not in added_cond_kwargs:
1017
+ raise ValueError(
1018
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1019
+ )
1020
+
1021
+ image_embeds = added_cond_kwargs.get("image_embeds")
1022
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1023
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1024
+ # Kandinsky 2.2 - style
1025
+ if "image_embeds" not in added_cond_kwargs:
1026
+ raise ValueError(
1027
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1028
+ )
1029
+ image_embeds = added_cond_kwargs.get("image_embeds")
1030
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1031
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1032
+ if "image_embeds" not in added_cond_kwargs:
1033
+ raise ValueError(
1034
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1035
+ )
1036
+
1037
+ if hasattr(self, 'text_encoder_hid_proj') and not self.text_encoder_hid_proj is None:
1038
+ encoder_hidden_states = self.text_encoder_hid_proj( encoder_hidden_states )
1039
+
1040
+ image_embeds = added_cond_kwargs.get("image_embeds")
1041
+ image_embeds = self.encoder_hid_proj(image_embeds)
1042
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1043
+ return encoder_hidden_states
1044
+
1045
+ def forward(
1046
+ self,
1047
+ sample: torch.Tensor,
1048
+ timestep: Union[torch.Tensor, float, int],
1049
+ encoder_hidden_states: torch.Tensor,
1050
+ class_labels: Optional[torch.Tensor] = None,
1051
+ timestep_cond: Optional[torch.Tensor] = None,
1052
+ attention_mask: Optional[torch.Tensor] = None,
1053
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1054
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1055
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1056
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1057
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1058
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1059
+ return_dict: bool = True,
1060
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1061
+ r"""
1062
+ The [`UNet2DConditionModel`] forward method.
1063
+
1064
+ Args:
1065
+ sample (`torch.Tensor`):
1066
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1067
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1068
+ encoder_hidden_states (`torch.Tensor`):
1069
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1070
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1071
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1072
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1073
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1074
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1075
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1076
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1077
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1078
+ negative values to the attention scores corresponding to "discard" tokens.
1079
+ cross_attention_kwargs (`dict`, *optional*):
1080
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1081
+ `self.processor` in
1082
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1083
+ added_cond_kwargs: (`dict`, *optional*):
1084
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1085
+ are passed along to the UNet blocks.
1086
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1087
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1088
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1089
+ A tensor that if specified is added to the residual of the middle unet block.
1090
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1091
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1092
+ encoder_attention_mask (`torch.Tensor`):
1093
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1094
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1095
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1096
+ return_dict (`bool`, *optional*, defaults to `True`):
1097
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1098
+ tuple.
1099
+
1100
+ Returns:
1101
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1102
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1103
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1104
+ """
1105
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
1106
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1107
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1108
+ # on the fly if necessary.
1109
+ default_overall_up_factor = 2**self.num_upsamplers
1110
+
1111
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1112
+ forward_upsample_size = False
1113
+ upsample_size = None
1114
+
1115
+ for dim in sample.shape[-2:]:
1116
+ if dim % default_overall_up_factor != 0:
1117
+ # Forward upsample size to force interpolation output size.
1118
+ forward_upsample_size = True
1119
+ break
1120
+
1121
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1122
+ # expects mask of shape:
1123
+ # [batch, key_tokens]
1124
+ # adds singleton query_tokens dimension:
1125
+ # [batch, 1, key_tokens]
1126
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1127
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1128
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1129
+ if attention_mask is not None:
1130
+ # assume that mask is expressed as:
1131
+ # (1 = keep, 0 = discard)
1132
+ # convert mask into a bias that can be added to attention scores:
1133
+ # (keep = +0, discard = -10000.0)
1134
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1135
+ attention_mask = attention_mask.unsqueeze(1)
1136
+
1137
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1138
+ if encoder_attention_mask is not None:
1139
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1140
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1141
+
1142
+ # 0. center input if necessary
1143
+ if self.config.center_input_sample:
1144
+ sample = 2 * sample - 1.0
1145
+
1146
+ # 1. time
1147
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1148
+ emb = self.time_embedding(t_emb, timestep_cond)
1149
+ aug_emb = None
1150
+
1151
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1152
+ if class_emb is not None:
1153
+ if self.config.class_embeddings_concat:
1154
+ emb = torch.cat([emb, class_emb], dim=-1)
1155
+ else:
1156
+ emb = emb + class_emb
1157
+
1158
+ aug_emb = self.get_aug_embed(
1159
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1160
+ )
1161
+ if self.config.addition_embed_type == "image_hint":
1162
+ aug_emb, hint = aug_emb
1163
+ sample = torch.cat([sample, hint], dim=1)
1164
+
1165
+ emb = emb + aug_emb if aug_emb is not None else emb
1166
+
1167
+ if self.time_embed_act is not None:
1168
+ emb = self.time_embed_act(emb)
1169
+
1170
+ encoder_hidden_states = self.process_encoder_hidden_states(
1171
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1172
+ )
1173
+
1174
+ # 2. pre-process
1175
+ sample = self.conv_in(sample)
1176
+
1177
+ # 2.5 GLIGEN position net
1178
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1179
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1180
+ gligen_args = cross_attention_kwargs.pop("gligen")
1181
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1182
+
1183
+ # 3. down
1184
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1185
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1186
+ if cross_attention_kwargs is not None:
1187
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1188
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1189
+ else:
1190
+ lora_scale = 1.0
1191
+
1192
+ if USE_PEFT_BACKEND:
1193
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1194
+ scale_lora_layers(self, lora_scale)
1195
+
1196
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1197
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1198
+ is_adapter = down_intrablock_additional_residuals is not None
1199
+ # maintain backward compatibility for legacy usage, where
1200
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1201
+ # but can only use one or the other
1202
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1203
+ deprecate(
1204
+ "T2I should not use down_block_additional_residuals",
1205
+ "1.3.0",
1206
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1207
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1208
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1209
+ standard_warn=False,
1210
+ )
1211
+ down_intrablock_additional_residuals = down_block_additional_residuals
1212
+ is_adapter = True
1213
+
1214
+ down_block_res_samples = (sample,)
1215
+ for downsample_block in self.down_blocks:
1216
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1217
+ # For t2i-adapter CrossAttnDownBlock2D
1218
+ additional_residuals = {}
1219
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1220
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1221
+
1222
+ sample, res_samples = downsample_block(
1223
+ hidden_states=sample,
1224
+ temb=emb,
1225
+ encoder_hidden_states=encoder_hidden_states,
1226
+ attention_mask=attention_mask,
1227
+ cross_attention_kwargs=cross_attention_kwargs,
1228
+ encoder_attention_mask=encoder_attention_mask,
1229
+ **additional_residuals,
1230
+ )
1231
+ else:
1232
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1233
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1234
+ sample += down_intrablock_additional_residuals.pop(0)
1235
+
1236
+ down_block_res_samples += res_samples
1237
+
1238
+ if is_controlnet:
1239
+ new_down_block_res_samples = ()
1240
+
1241
+ for down_block_res_sample, down_block_additional_residual in zip(
1242
+ down_block_res_samples, down_block_additional_residuals
1243
+ ):
1244
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1245
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1246
+
1247
+ down_block_res_samples = new_down_block_res_samples
1248
+
1249
+ # 4. mid
1250
+ if self.mid_block is not None:
1251
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1252
+ sample = self.mid_block(
1253
+ sample,
1254
+ emb,
1255
+ encoder_hidden_states=encoder_hidden_states,
1256
+ attention_mask=attention_mask,
1257
+ cross_attention_kwargs=cross_attention_kwargs,
1258
+ encoder_attention_mask=encoder_attention_mask,
1259
+ )
1260
+ else:
1261
+ sample = self.mid_block(sample, emb)
1262
+
1263
+ # To support T2I-Adapter-XL
1264
+ if (
1265
+ is_adapter
1266
+ and len(down_intrablock_additional_residuals) > 0
1267
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1268
+ ):
1269
+ sample += down_intrablock_additional_residuals.pop(0)
1270
+
1271
+ if is_controlnet:
1272
+ sample = sample + mid_block_additional_residual
1273
+
1274
+ # 5. up
1275
+ for i, upsample_block in enumerate(self.up_blocks):
1276
+ is_final_block = i == len(self.up_blocks) - 1
1277
+
1278
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1279
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1280
+
1281
+ # if we have not reached the final block and need to forward the
1282
+ # upsample size, we do it here
1283
+ if not is_final_block and forward_upsample_size:
1284
+ upsample_size = down_block_res_samples[-1].shape[2:]
1285
+
1286
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1287
+ sample = upsample_block(
1288
+ hidden_states=sample,
1289
+ temb=emb,
1290
+ res_hidden_states_tuple=res_samples,
1291
+ encoder_hidden_states=encoder_hidden_states,
1292
+ cross_attention_kwargs=cross_attention_kwargs,
1293
+ upsample_size=upsample_size,
1294
+ attention_mask=attention_mask,
1295
+ encoder_attention_mask=encoder_attention_mask,
1296
+ )
1297
+ else:
1298
+ sample = upsample_block(
1299
+ hidden_states=sample,
1300
+ temb=emb,
1301
+ res_hidden_states_tuple=res_samples,
1302
+ upsample_size=upsample_size,
1303
+ )
1304
+
1305
+ # 6. post-process
1306
+ if self.conv_norm_out:
1307
+ sample = self.conv_norm_out(sample)
1308
+ sample = self.conv_act(sample)
1309
+ sample = self.conv_out(sample)
1310
+
1311
+ if USE_PEFT_BACKEND:
1312
+ # remove `lora_scale` from each PEFT layer
1313
+ unscale_lora_layers(self, lora_scale)
1314
+
1315
+ if not return_dict:
1316
+ return (sample,)
1317
+
1318
+ return UNet2DConditionOutput(sample=sample)