jadechoghari committed
Commit
da7256e
1 Parent(s): 7b5beb5

Create pnp_utils.py

Files changed (1)
  1. pnp_utils.py +172 -0
pnp_utils.py ADDED
@@ -0,0 +1,172 @@
import torch
import os
import random
import numpy as np

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

def register_time(model, t):
    # register current timestamp to each layer
    down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1], 3: [0, 1]}
    up_res_dict = {0: [0, 1, 2], 1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in up_res_dict:
        for block in up_res_dict[res]:
            if hasattr(model.unet.up_blocks[res], "attentions"):
                module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
                setattr(module, 't', t)
                module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn2
                setattr(module, 't', t)
            conv_module = model.unet.up_blocks[res].resnets[block]
            setattr(conv_module, 't', t)
    for res in down_res_dict:
        for block in down_res_dict[res]:
            if hasattr(model.unet.down_blocks[res], "attentions"):
                module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
                setattr(module, 't', t)
                module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn2
                setattr(module, 't', t)
            conv_module = model.unet.down_blocks[res].resnets[block]
            setattr(conv_module, 't', t)
    module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
    setattr(module, 't', t)
    module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn2
    setattr(module, 't', t)

def register_attention_control(model, injection_schedule, num_inputs):
    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None, **kwargs):
            batch_size, sequence_length, dim = x.shape
            h = self.heads

            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x

            v = self.to_v(encoder_hidden_states)
            v = self.head_to_batch_dim(v)

            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                q = self.to_q(x)
                k = self.to_k(encoder_hidden_states)

                source_batch_size = int(q.shape[0] // num_inputs)

                q = q[:source_batch_size]
                k = k[:source_batch_size]
                q = self.head_to_batch_dim(q)
                k = self.head_to_batch_dim(k)
            else:
                q = self.to_q(x)
                k = self.to_k(encoder_hidden_states)
                q = self.head_to_batch_dim(q)
                k = self.head_to_batch_dim(k)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~attention_mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)

            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                # Inject attention map from source
                # attn = torch.cat([attn] * num_inputs, dim=0)
                attn = attn.repeat(num_inputs, 1, 1)

            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
    print("[INFO-PnP] Register Source Attention QK Injection in Up Res", res_dict)

def register_conv_control(model, injection_schedule, num_inputs):
    def conv_forward(self):
        def forward(input_tensor, temb, **kwargs):
            hidden_states = input_tensor

            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            elif self.downsample is not None:
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)

            hidden_states = self.conv1(hidden_states)

            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

            if temb is not None and self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb

            hidden_states = self.norm2(hidden_states)

            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden_states = hidden_states * (1 + scale) + shift

            hidden_states = self.nonlinearity(hidden_states)

            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states)
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = int(hidden_states.shape[0] // num_inputs)

                # inject unconditional
                hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
                # inject conditional
                if num_inputs > 2:
                    hidden_states[2 * source_batch_size:3 * source_batch_size] = hidden_states[:source_batch_size]

            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)

            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

            return output_tensor

        return forward

    res_dict = {1: [1]}
    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, 'injection_schedule', injection_schedule)
    print("[INFO-PnP] Register Source Feature Injection in Up Res", res_dict)