SunderAli17 committed
Commit 0fca383
1 Parent(s): 2e62d1b

Create attention_processor.py

Files changed (1)
  1. toonmage/attention_processor.py +422 -0
toonmage/attention_processor.py ADDED
@@ -0,0 +1,422 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# NUM_ZERO: number of zero tokens appended to the ID embedding before the key/value projections.
# ORTHO / ORTHO_v2: if enabled, the ID features are merged via an orthogonal projection instead of plain addition.
NUM_ZERO = 0
ORTHO = False
ORTHO_v2 = False


class AttnProcessor(nn.Module):
    def __init__(self):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        id_embedding=None,
        id_scale=1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IDAttnProcessor(nn.Module):
    r"""
    Attention processor for ID-Adapter.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
    """

    def __init__(self, hidden_size, cross_attention_dim=None):
        super().__init__()
        self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        id_embedding=None,
        id_scale=1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for id-adapter
        if id_embedding is not None:
            if NUM_ZERO == 0:
                id_key = self.id_to_k(id_embedding)
                id_value = self.id_to_v(id_embedding)
            else:
                zero_tensor = torch.zeros(
                    (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
                    dtype=id_embedding.dtype,
                    device=id_embedding.device,
                )
                id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1))
                id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1))

            id_key = attn.head_to_batch_dim(id_key).to(query.dtype)
            id_value = attn.head_to_batch_dim(id_value).to(query.dtype)

            id_attention_probs = attn.get_attention_scores(query, id_key, None)
            id_hidden_states = torch.bmm(id_attention_probs, id_value)
            id_hidden_states = attn.batch_to_head_dim(id_hidden_states)

            if not ORTHO:
                hidden_states = hidden_states + id_scale * id_hidden_states
            else:
                projection = (
                    torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = id_hidden_states - projection
                hidden_states = hidden_states + id_scale * orthogonal

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        id_embedding=None,
        id_scale=1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IDAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for ID-Adapter with PyTorch 2.0's scaled dot-product attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IDAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

        self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        id_embedding=None,
        id_scale=1.0,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for id embedding
        if id_embedding is not None:
            if NUM_ZERO == 0:
                id_key = self.id_to_k(id_embedding).to(query.dtype)
                id_value = self.id_to_v(id_embedding).to(query.dtype)
            else:
                zero_tensor = torch.zeros(
                    (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
                    dtype=id_embedding.dtype,
                    device=id_embedding.device,
                )
                id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
                id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)

            id_key = id_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            id_value = id_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            id_hidden_states = F.scaled_dot_product_attention(
                query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            id_hidden_states = id_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            id_hidden_states = id_hidden_states.to(query.dtype)

            if not ORTHO and not ORTHO_v2:
                hidden_states = hidden_states + id_scale * id_hidden_states
            elif ORTHO_v2:
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                id_hidden_states = id_hidden_states.to(torch.float32)
                attn_map = query @ id_key.transpose(-2, -1)
                attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
                attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
                projection = (
                    torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = id_hidden_states + (attn_mean - 1) * projection
                hidden_states = hidden_states + id_scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)
            else:
                orig_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                id_hidden_states = id_hidden_states.to(torch.float32)
                projection = (
                    torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
                    / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                    * hidden_states
                )
                orthogonal = id_hidden_states - projection
                hidden_states = hidden_states + id_scale * orthogonal
                hidden_states = hidden_states.to(orig_dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
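
Usage note (not part of this commit): these classes follow the standard diffusers attention-processor interface, so they are registered on a UNet rather than called directly. The sketch below is a minimal, illustrative wiring under the usual IP-Adapter-style layer-naming conventions; the base model id, the `unet` variable, and the name parsing are assumptions, not defined in this file, and in practice the `id_to_k`/`id_to_v` weights would be loaded from a trained ToonMage checkpoint (not shown).

from diffusers import UNet2DConditionModel
from toonmage.attention_processor import AttnProcessor2_0, IDAttnProcessor2_0

# Assumed base model; any diffusers UNet2DConditionModel is wired the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") take no encoder_hidden_states; cross-attention layers ("attn2") do.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()  # plain attention, no ID branch
    else:
        attn_procs[name] = IDAttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(attn_procs)

# At inference, id_embedding and id_scale reach the processors through diffusers'
# cross_attention_kwargs mechanism, e.g.
# unet(sample, timestep, encoder_hidden_states,
#      cross_attention_kwargs={"id_embedding": id_embedding, "id_scale": 0.8})

On the merge strategy itself: with ORTHO and ORTHO_v2 both False, the ID branch is simply added to the attention output weighted by id_scale; when either flag is enabled, the ID features are first projected onto the existing hidden states along the sequence dimension and only the orthogonal remainder (re-weighted by an attention-derived factor in the v2 variant) is added, which limits how much the identity branch can overwrite the base attention output.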