linoyts (HF staff) committed
Commit 39c8554
1 Parent(s): 58ec06e

Create clip_slider_pipeline.py

Files changed (1)
  1. clip_slider_pipeline.py +421 -0
clip_slider_pipeline.py ADDED
@@ -0,0 +1,421 @@
+ import diffusers
+ import torch
+ import random
+ from tqdm import tqdm
+ from constants import SUBJECTS, MEDIUMS
+ from PIL import Image
+
+ class CLIPSlider:
+     def __init__(
+             self,
+             sd_pipe,
+             device: torch.device,
+             target_word: str,
+             opposite: str,
+             target_word_2nd: str = "",
+             opposite_2nd: str = "",
+             iterations: int = 300,
+     ):
+
+         self.device = device
+         self.pipe = sd_pipe.to(self.device)
+         self.iterations = iterations
+         self.avg_diff = self.find_latent_direction(target_word, opposite)
+         if target_word_2nd != "" or opposite_2nd != "":
+             self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
+         else:
+             self.avg_diff_2nd = None
+
+
+     def find_latent_direction(self,
+                               target_word: str,
+                               opposite: str):
+
+         # let's identify a latent direction by taking differences between opposites
+         # target_word = "happy"
+         # opposite = "sad"
+
+
+         with torch.no_grad():
+             positives = []
+             negatives = []
+             for i in tqdm(range(self.iterations)):
+                 medium = random.choice(MEDIUMS)
+                 subject = random.choice(SUBJECTS)
+                 pos_prompt = f"a {medium} of a {target_word} {subject}"
+                 neg_prompt = f"a {medium} of a {opposite} {subject}"
+                 pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 pos = self.pipe.text_encoder(pos_toks).pooler_output
+                 neg = self.pipe.text_encoder(neg_toks).pooler_output
+                 positives.append(pos)
+                 negatives.append(neg)
+
+         positives = torch.cat(positives, dim=0)
+         negatives = torch.cat(negatives, dim=0)
+
+         diffs = positives - negatives
+
+         # average the embedding difference over many random prompts
+         avg_diff = diffs.mean(0, keepdim=True)
+         return avg_diff
+
+
+     def generate(self,
+                  prompt = "a photo of a house",
+                  scale = 2.,
+                  scale_2nd = 0., # scale for the 2nd dim directions when avg_diff_2nd is not None
+                  seed = 15,
+                  only_pooler = False,
+                  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
+                  correlation_weight_factor = 1.0,
+                  **pipeline_kwargs
+                  ):
+         # if doing the full sequence, [-0.3, 0.3] works well, higher if correlation weighting is used
+         # if pooler token only, [-4, 4] works well
+
+         with torch.no_grad():
+             toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                        max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+             prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
+
+             if self.avg_diff_2nd is not None and normalize_scales:
+                 denominator = abs(scale) + abs(scale_2nd)
+                 scale = scale / denominator
+                 scale_2nd = scale_2nd / denominator
+             if only_pooler:
+                 # shift only the embedding at the EOS (pooled) token position
+                 prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
+                 if self.avg_diff_2nd is not None:
+                     prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
+             else:
+                 # weight each token's shift by its similarity to the EOS token
+                 normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
+                 sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
+                 weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
+
+                 standard_weights = torch.ones_like(weights)
+
+                 weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
+
+                 # weights = torch.sigmoid((weights-0.5)*7)
+                 prompt_embeds = prompt_embeds + (
+                     weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                 if self.avg_diff_2nd is not None:
+                     prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
+
+
+         torch.manual_seed(seed)
+         image = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
+
+         return image
+
+     def spectrum(self,
+                  prompt="a photo of a house",
+                  low_scale=-2,
+                  low_scale_2nd=-2,
+                  high_scale=2,
+                  high_scale_2nd=2,
+                  steps=5,
+                  seed=15,
+                  only_pooler=False,
+                  normalize_scales=False,
+                  correlation_weight_factor=1.0,
+                  **pipeline_kwargs
+                  ):
+
+         images = []
+         for i in range(steps):
+             scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
+             scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
+             image = self.generate(prompt, scale, scale_2nd, seed, only_pooler, normalize_scales, correlation_weight_factor, **pipeline_kwargs)
+             images.append(image[0])
+
+         canvas = Image.new('RGB', (640 * steps, 640))
+         for i, im in enumerate(images):
+             canvas.paste(im, (640 * i, 0))
+
+         return canvas
+
+ class CLIPSliderXL(CLIPSlider):
+
+     def find_latent_direction(self,
+                               target_word: str,
+                               opposite: str):
+
+         # let's identify a latent direction by taking differences between opposites
+         # target_word = "happy"
+         # opposite = "sad"
+
+
+         with torch.no_grad():
+             positives = []
+             negatives = []
+             positives2 = []
+             negatives2 = []
+             for i in tqdm(range(self.iterations)):
+                 medium = random.choice(MEDIUMS)
+                 subject = random.choice(SUBJECTS)
+                 pos_prompt = f"a {medium} of a {target_word} {subject}"
+                 neg_prompt = f"a {medium} of a {opposite} {subject}"
+
+                 pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 pos = self.pipe.text_encoder(pos_toks).pooler_output
+                 neg = self.pipe.text_encoder(neg_toks).pooler_output
+                 positives.append(pos)
+                 negatives.append(neg)
+
+                 pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                   max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                 neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                   max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                 pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
+                 neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
+                 positives2.append(pos2)
+                 negatives2.append(neg2)
+
+         # one direction per SDXL text encoder
+         positives = torch.cat(positives, dim=0)
+         negatives = torch.cat(negatives, dim=0)
+         diffs = positives - negatives
+         avg_diff = diffs.mean(0, keepdim=True)
+
+         positives2 = torch.cat(positives2, dim=0)
+         negatives2 = torch.cat(negatives2, dim=0)
+         diffs2 = positives2 - negatives2
+         avg_diff2 = diffs2.mean(0, keepdim=True)
+         return (avg_diff, avg_diff2)
+
+     def generate(self,
+                  prompt = "a photo of a house",
+                  scale = 2,
+                  scale_2nd = 2,
+                  seed = 15,
+                  only_pooler = False,
+                  normalize_scales = False,
+                  correlation_weight_factor = 1.0,
+                  **pipeline_kwargs
+                  ):
+         # if doing the full sequence, [-0.3, 0.3] works well, higher if correlation weighting is used
+         # if pooler token only, [-4, 4] works well
+
+         text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
+         tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
+         with torch.no_grad():
+             # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
+             # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
+
+             prompt_embeds_list = []
+
+             for i, text_encoder in enumerate(text_encoders):
+
+                 tokenizer = tokenizers[i]
+                 text_inputs = tokenizer(
+                     prompt,
+                     padding="max_length",
+                     max_length=tokenizer.model_max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+                 toks = text_inputs.input_ids
+
+                 prompt_embeds = text_encoder(
+                     toks.to(text_encoder.device),
+                     output_hidden_states=True,
+                 )
+
+                 # We are only ALWAYS interested in the pooled output of the final text encoder
+                 pooled_prompt_embeds = prompt_embeds[0]
+                 prompt_embeds = prompt_embeds.hidden_states[-2]
+
+                 if self.avg_diff_2nd is not None and normalize_scales:
+                     denominator = abs(scale) + abs(scale_2nd)
+                     scale = scale / denominator
+                     scale_2nd = scale_2nd / denominator
+                 if only_pooler:
+                     # use the direction that matches this encoder's hidden size
+                     prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
+                     if self.avg_diff_2nd is not None:
+                         prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[i] * scale_2nd
+                 else:
+                     normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
+                     sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
+
+                     if i == 0:
+                         weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
+
+                         standard_weights = torch.ones_like(weights)
+
+                         weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
+                         prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                         if self.avg_diff_2nd is not None:
+                             prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
+                     else:
+                         weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
+
+                         standard_weights = torch.ones_like(weights)
+
+                         weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
+                         prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+                         if self.avg_diff_2nd is not None:
+                             prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
+
+                 bs_embed, seq_len, _ = prompt_embeds.shape
+                 prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+                 prompt_embeds_list.append(prompt_embeds)
+
+             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+             pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+
+         torch.manual_seed(seed)
+         image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+                           **pipeline_kwargs).images
+
+         return image
+
+
+ class CLIPSlider3(CLIPSlider):
+     def find_latent_direction(self,
+                               target_word: str,
+                               opposite: str):
+
+         # let's identify a latent direction by taking differences between opposites
+         # target_word = "happy"
+         # opposite = "sad"
+
+
+         with torch.no_grad():
+             positives = []
+             negatives = []
+             positives2 = []
+             negatives2 = []
+             for i in tqdm(range(self.iterations)):
+                 medium = random.choice(MEDIUMS)
+                 subject = random.choice(SUBJECTS)
+                 pos_prompt = f"a {medium} of a {target_word} {subject}"
+                 neg_prompt = f"a {medium} of a {opposite} {subject}"
+
+                 pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                 pos = self.pipe.text_encoder(pos_toks).text_embeds
+                 neg = self.pipe.text_encoder(neg_toks).text_embeds
+                 positives.append(pos)
+                 negatives.append(neg)
+
+                 pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                   max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                 neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                   max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                 pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
+                 neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
+                 positives2.append(pos2)
+                 negatives2.append(neg2)
+
+         # one direction per CLIP text encoder (the T5 encoder is not used for the direction)
+         positives = torch.cat(positives, dim=0)
+         negatives = torch.cat(negatives, dim=0)
+         diffs = positives - negatives
+         avg_diff = diffs.mean(0, keepdim=True)
+
+         positives2 = torch.cat(positives2, dim=0)
+         negatives2 = torch.cat(negatives2, dim=0)
+         diffs2 = positives2 - negatives2
+         avg_diff2 = diffs2.mean(0, keepdim=True)
+         return (avg_diff, avg_diff2)
+
+     def generate(self,
+                  prompt = "a photo of a house",
+                  scale = 2,
+                  seed = 15,
+                  only_pooler = False,
+                  correlation_weight_factor = 1.0,
+                  **pipeline_kwargs
+                  ):
+         # if doing the full sequence, [-0.3, 0.3] works well, higher if correlation weighting is used
+         # if pooler token only, [-4, 4] works well
+
+         # SD3 uses two CLIP encoders plus a T5 encoder; the loop visits all three
+         # so that the T5 embeddings are produced alongside the edited CLIP embeddings
+         text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2, self.pipe.text_encoder_3]
+         clip_tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
+         with torch.no_grad():
+             # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
+             # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
+
+             clip_prompt_embeds_list = []
+             clip_pooled_prompt_embeds_list = []
+             for i, text_encoder in enumerate(text_encoders):
+
+                 if i < 2:
+                     tokenizer = clip_tokenizers[i]
+                     text_inputs = tokenizer(
+                         prompt,
+                         padding="max_length",
+                         max_length=tokenizer.model_max_length,
+                         truncation=True,
+                         return_tensors="pt",
+                     )
+                     toks = text_inputs.input_ids
+
+                     prompt_embeds = text_encoder(
+                         toks.to(text_encoder.device),
+                         output_hidden_states=True,
+                     )
+
+                     # We are only ALWAYS interested in the pooled output of the final text encoder
+                     pooled_prompt_embeds = prompt_embeds[0]
+                     pooled_prompt_embeds = pooled_prompt_embeds.view(pooled_prompt_embeds.shape[0], -1)
+                     clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+                     prompt_embeds = prompt_embeds.hidden_states[-2]
+                 else:
+                     text_inputs = self.pipe.tokenizer_3(
+                         prompt,
+                         padding="max_length",
+                         max_length=self.pipe.tokenizer_max_length,
+                         truncation=True,
+                         add_special_tokens=True,
+                         return_tensors="pt",
+                     )
+                     toks = text_inputs.input_ids
+                     prompt_embeds = self.pipe.text_encoder_3(toks.to(self.device))[0]
+                     t5_prompt_embed_shape = prompt_embeds.shape[-1]
+
+                 if i < 2:
+                     # the latent directions come from the CLIP encoders, so only the CLIP
+                     # embeddings are edited; the T5 embeddings pass through unchanged
+                     if only_pooler:
+                         prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
+                     else:
+                         normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
+                         sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
+                         if i == 0:
+                             weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
+
+                             standard_weights = torch.ones_like(weights)
+
+                             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
+                             prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                         else:
+                             weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
+
+                             standard_weights = torch.ones_like(weights)
+
+                             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
+                             prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+
+                 bs_embed, seq_len, _ = prompt_embeds.shape
+                 prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+                 if i < 2:
+                     clip_prompt_embeds_list.append(prompt_embeds)
+
+             clip_prompt_embeds = torch.concat(clip_prompt_embeds_list, dim=-1)
+             clip_pooled_prompt_embeds = torch.concat(clip_pooled_prompt_embeds_list, dim=-1)
+
+             # pad the CLIP embeddings to the T5 width and stack them along the sequence axis,
+             # matching the prompt-embedding layout the SD3 pipeline expects
+             clip_prompt_embeds = torch.nn.functional.pad(
+                 clip_prompt_embeds, (0, t5_prompt_embed_shape - clip_prompt_embeds.shape[-1])
+             )
+
+             prompt_embeds = torch.cat([clip_prompt_embeds, prompt_embeds], dim=-2)
+
+
+         torch.manual_seed(seed)
+         image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=clip_pooled_prompt_embeds,
+                           **pipeline_kwargs).images
+
+         return image
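
A minimal usage sketch, assuming an SDXL checkpoint (stabilityai/stable-diffusion-xl-base-1.0 is used here only as an example), a CUDA device, and a local constants.py providing the SUBJECTS and MEDIUMS lists that the module imports:

import torch
from diffusers import StableDiffusionXLPipeline
from clip_slider_pipeline import CLIPSliderXL

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# building the slider runs `iterations` prompt pairs through the text encoders,
# so a smaller value keeps this example fast
slider = CLIPSliderXL(
    sd_pipe=pipe,
    device=torch.device("cuda"),
    target_word="happy",
    opposite="sad",
    iterations=50,
)

# a positive scale pushes toward "happy", a negative one toward "sad"
image = slider.generate("a photo of a house", scale=1.5, num_inference_steps=30)[0]
image.save("happy_house.png")

# render the whole slider range as one strip; 640x640 matches the canvas size
# that spectrum() assumes when pasting the images side by side
strip = slider.spectrum("a photo of a house", low_scale=-2, high_scale=2, steps=5,
                        height=640, width=640)
strip.save("spectrum.png")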