take avg_diff out of attributes and save to cpu (#1)
Commits in this PR:
- take avg_diff out of attributes and save to cpu (a64b549e3ec96a9cf832c237af6bd7dd3baae807)
- Update clip_slider_pipeline.py (224212aeaf4f1ed103d432f5d08e8eee69601810)

Files changed:
- app.py: +9 -5
- clip_slider_pipeline.py: +22 -165
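The gist of the change: the average-difference direction tensors are no longer attributes of the slider pipeline object. find_latent_direction returns them, the app keeps them on the CPU between Gradio calls, and each callback moves them back to the GPU and passes them into generate() explicitly. A minimal, self-contained sketch of that hand-off pattern (plain tensors stand in for the CLIP text-embedding directions; the shapes and the exact .cpu()/.cuda() placement are assumptions, not code from this repo):

import torch

# Stand-ins for the two per-encoder direction tensors produced by find_latent_direction
# (shapes assumed: 768-d for the first CLIP text encoder, 1280-d for the second).
avg_diff_x_1 = torch.randn(1, 768)
avg_diff_x_2 = torch.randn(1, 1280)

# Keep the state payload on the CPU so it does not pin GPU memory between requests.
avg_diff_x_1, avg_diff_x_2 = avg_diff_x_1.cpu(), avg_diff_x_2.cpu()

# Inside a slider callback: move back to the GPU (if available) before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
avg_diff = [avg_diff_x_1.to(device), avg_diff_x_2.to(device)]

# The pipeline call then receives the directions explicitly, e.g.
#   clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8,
#                        avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)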
app.py CHANGED

@@ -17,14 +17,14 @@ def generate(slider_x, slider_y, prompt,
 
     # check if avg diff for directions need to be re-calculated
     if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
-
+        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
         x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
 
     if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
-
+        avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
         y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
 
-    image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
+    image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
     comma_concepts_x = ', '.join(slider_x)
     comma_concepts_y = ', '.join(slider_y)
 
@@ -36,11 +36,15 @@ def generate(slider_x, slider_y, prompt,
     return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, image
 
 def update_x(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
-
+    avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
+    avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
+    image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
     return image
 
 def update_y(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
-
+    avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
+    avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
+    image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
     return image
 
 css = '''
clip_slider_pipeline.py CHANGED

@@ -73,6 +73,8 @@ class CLIPSlider:
                  only_pooler = False,
                  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
                  correlation_weight_factor = 1.0,
+                 avg_diff = None,
+                 avg_diff_2nd = None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -83,14 +85,14 @@ class CLIPSlider:
                                        max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
             prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
 
-        if
+        if avg_diff_2nd and normalize_scales:
             denominator = abs(scale) + abs(scale_2nd)
             scale = scale / denominator
             scale_2nd = scale_2nd / denominator
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
-            if
-                prompt_embeds[:, toks.argmax()] +=
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
+            if avg_diff_2nd:
+                prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
         else:
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -102,9 +104,9 @@ class CLIPSlider:
 
             # weights = torch.sigmoid((weights-0.5)*7)
             prompt_embeds = prompt_embeds + (
-                weights *
-            if
-                prompt_embeds += weights *
+                weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+            if avg_diff_2nd:
+                prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
 
 
         torch.manual_seed(seed)
@@ -198,6 +200,8 @@ class CLIPSliderXL(CLIPSlider):
                  only_pooler = False,
                  normalize_scales = False,
                  correlation_weight_factor = 1.0,
+                 avg_diff = None,
+                 avg_diff_2nd = None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -232,16 +236,16 @@ class CLIPSliderXL(CLIPSlider):
                     pooled_prompt_embeds = prompt_embeds[0]
                     prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                    if
+                    if avg_diff_2nd and normalize_scales:
                         denominator = abs(scale) + abs(scale_2nd)
                         scale = scale / denominator
                         scale_2nd = scale_2nd / denominator
                     if only_pooler:
-                        prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
-                        if
-                            prompt_embeds[:, toks.argmax()] +=
+                        prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
+                        if avg_diff_2nd:
+                            prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
                     else:
-                        print(
+                        print(avg_diff)
                         normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                         sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
@@ -251,18 +255,18 @@ class CLIPSliderXL(CLIPSlider):
                             standard_weights = torch.ones_like(weights)
 
                             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                            prompt_embeds = prompt_embeds + (weights *
-                            if
-                                prompt_embeds += (weights *
+                            prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                            if avg_diff_2nd:
+                                prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
                         else:
                             weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
                             standard_weights = torch.ones_like(weights)
 
                             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                            prompt_embeds = prompt_embeds + (weights *
-                            if
-                                prompt_embeds += (weights *
+                            prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+                            if avg_diff_2nd:
+                                prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
 
                     bs_embed, seq_len, _ = prompt_embeds.shape
                     prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
@@ -276,150 +280,3 @@ class CLIPSliderXL(CLIPSlider):
                           **pipeline_kwargs).images[0]
 
         return image
-
-
-class CLIPSlider3(CLIPSlider):
-    def find_latent_direction(self,
-                              target_word:str,
-                              opposite:str):
-
-        # lets identify a latent direction by taking differences between opposites
-        # target_word = "happy"
-        # opposite = "sad"
-
-
-        with torch.no_grad():
-            positives = []
-            negatives = []
-            positives2 = []
-            negatives2 = []
-            for i in tqdm(range(self.iterations)):
-                medium = random.choice(MEDIUMS)
-                subject = random.choice(SUBJECTS)
-                pos_prompt = f"a {medium} of a {target_word} {subject}"
-                neg_prompt = f"a {medium} of a {opposite} {subject}"
-
-                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
-                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
-                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
-                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
-                pos = self.pipe.text_encoder(pos_toks).text_embeds
-                neg = self.pipe.text_encoder(neg_toks).text_embeds
-                positives.append(pos)
-                negatives.append(neg)
-
-                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
-                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
-                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
-                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
-                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
-                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
-                positives2.append(pos2)
-                negatives2.append(neg2)
-
-        positives = torch.cat(positives, dim=0)
-        negatives = torch.cat(negatives, dim=0)
-        diffs = positives - negatives
-        avg_diff = diffs.mean(0, keepdim=True)
-
-        positives2 = torch.cat(positives2, dim=0)
-        negatives2 = torch.cat(negatives2, dim=0)
-        diffs2 = positives2 - negatives2
-        avg_diff2 = diffs2.mean(0, keepdim=True)
-        return (avg_diff, avg_diff2)
-
-    def generate(self,
-                 prompt = "a photo of a house",
-                 scale = 2,
-                 seed = 15,
-                 only_pooler = False,
-                 correlation_weight_factor = 1.0,
-                 ** pipeline_kwargs
-                 ):
-        # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
-        # if pooler token only [-4,4] work well
-        clip_text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
-        clip_tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
-        with torch.no_grad():
-            # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
-            # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
-
-            clip_prompt_embeds_list = []
-            clip_pooled_prompt_embeds_list = []
-            for i, text_encoder in enumerate(clip_text_encoders):
-
-                if i < 2:
-                    tokenizer = clip_tokenizers[i]
-                    text_inputs = tokenizer(
-                        prompt,
-                        padding="max_length",
-                        max_length=tokenizer.model_max_length,
-                        truncation=True,
-                        return_tensors="pt",
-                    )
-                    toks = text_inputs.input_ids
-
-                    prompt_embeds = text_encoder(
-                        toks.to(text_encoder.device),
-                        output_hidden_states=True,
-                    )
-
-                    # We are only ALWAYS interested in the pooled output of the final text encoder
-                    pooled_prompt_embeds = prompt_embeds[0]
-                    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-                    clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
-                    prompt_embeds = prompt_embeds.hidden_states[-2]
-                else:
-                    text_inputs = self.pipe.tokenizer_3(
-                        prompt,
-                        padding="max_length",
-                        max_length=self.tokenizer_max_length,
-                        truncation=True,
-                        add_special_tokens=True,
-                        return_tensors="pt",
-                    )
-                    toks = text_inputs.input_ids
-                    prompt_embeds = self.pipe.text_encoder_3(toks.to(self.device))[0]
-                    t5_prompt_embed_shape = prompt_embeds.shape[-1]
-
-                if only_pooler:
-                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
-                else:
-                    normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
-                    sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
-                    if i == 0:
-                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
-
-                        standard_weights = torch.ones_like(weights)
-
-                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-                    else:
-                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
-
-                        standard_weights = torch.ones_like(weights)
-
-                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
-
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
-                if i < 2:
-                    clip_prompt_embeds_list.append(prompt_embeds)
-
-            clip_prompt_embeds = torch.concat(clip_prompt_embeds_list, dim=-1)
-            clip_pooled_prompt_embeds = torch.concat(clip_pooled_prompt_embeds_list, dim=-1)
-
-            clip_prompt_embeds = torch.nn.functional.pad(
-                clip_prompt_embeds, (0, t5_prompt_embed_shape - clip_prompt_embeds.shape[-1])
-            )
-
-            prompt_embeds = torch.cat([clip_prompt_embeds, prompt_embeds], dim=-2)
-
-
-
-        torch.manual_seed(seed)
-        image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=clip_pooled_prompt_embeds,
-                          **pipeline_kwargs).images[0]
-
-        return image
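For reference, the arithmetic the new avg_diff / avg_diff_2nd arguments feed into (the full-sequence branch of the XL path above) amounts to scaling each direction tensor and adding it to every token embedding, optionally normalizing the two scales first. A small self-contained sketch of that step, using dummy tensors and an assumed 77-token CLIP sequence length:

import torch

tokens = 77                                  # CLIP model_max_length (assumed)
prompt_embeds = torch.randn(1, tokens, 768)  # hidden states from the first text encoder
weights = torch.ones(1, tokens, 768)         # correlation weights (identity here)

avg_diff = (torch.randn(1, 768), torch.randn(1, 1280))      # x-axis direction, one tensor per encoder
avg_diff_2nd = (torch.randn(1, 768), torch.randn(1, 1280))  # y-axis direction, one tensor per encoder
scale, scale_2nd, normalize_scales = 2.0, -1.0, True

# Optionally normalize the two scales so their magnitudes sum to 1, as in the diff above.
if avg_diff_2nd is not None and normalize_scales:
    denominator = abs(scale) + abs(scale_2nd)
    scale, scale_2nd = scale / denominator, scale_2nd / denominator

# Add the token-weighted direction for the first encoder to every token embedding.
prompt_embeds = prompt_embeds + weights * avg_diff[0][None, :].repeat(1, tokens, 1) * scale
if avg_diff_2nd is not None:
    prompt_embeds = prompt_embeds + weights * avg_diff_2nd[0][None, :].repeat(1, tokens, 1) * scale_2nd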