AlekseyCalvin committed

Commit c3f6e82
Parent(s): 20e1000

Upload pipeline13.py

Files changed (1):
  1. pipeline13.py  +75 -74
pipeline13.py CHANGED
@@ -39,7 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  BASE_SEQ_LEN = 256
  MAX_SEQ_LEN = 4096
  BASE_SHIFT = 0.5
- MAX_SHIFT = 1.2
+ MAX_SHIFT = 1.16
 
  # Helper functions
  def calculate_timestep_shift(image_seq_len: int) -> float:
@@ -108,7 +108,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  self,
  prompt: Union[str, List[str]] = None,
  num_images_per_prompt: int = 1,
- max_sequence_length: int = 256,
+ max_sequence_length: int = 512,
  device: Optional[torch.device] = None,
  dtype: Optional[torch.dtype] = None,
  ):
@@ -179,16 +179,16 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  "The following part of your input was truncated because CLIP can only handle sequences up to"
  f" {self.tokenizer_max_length} tokens: {removed_text}"
  )
- prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
  # Use pooled output of CLIPTextModel
  prompt_embeds = prompt_embeds.pooler_output
  prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
- _, seq_len = prompt_embeds.shape
+ _, seq_len, _ = prompt_embeds.shape
 
  # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
  return prompt_embeds
@@ -274,21 +274,13 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  num_images_per_prompt=num_images_per_prompt,
  )
 
- t5_negative_prompt_embeds = self._get_t5_prompt_embeds(
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
  prompt=negative_prompt_2,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
  )
 
- negative_pooled_prompt_embeds = torch.nn.functional.pad(
- negative_pooled_prompt_embeds,
- (0, t5_negative_prompt_embeds.shape[-1] - negative_pooled_prompt_embeds.shape[-1]),
- )
-
- negative_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, t5_negative_prompt_embeds], dim=-2)
-
-
  if self.text_encoder is not None:
  if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
  # Retrieve the original scale by scaling back the LoRA layers
@@ -300,18 +292,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  unscale_lora_layers(self.text_encoder_2, lora_scale)
 
  dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
- negative_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+ negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
- negative_prompt_embeds = torch.unsqueeze(0)
- negative_pooled_prompt_embeds = torch.unsqueeze(0)
-
- return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
+ return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids
 
  def check_inputs(
  self,
@@ -319,8 +304,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  prompt_2,
  height,
  width,
- negative_prompt=None,
- negative_prompt_2=None,
  prompt_embeds=None,
  negative_prompt_embeds=None,
  pooled_prompt_embeds=None,
@@ -354,7 +337,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
  )
  if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+ raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")
 
  if max_sequence_length is not None and max_sequence_length > 512:
  raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -367,8 +350,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
 
  latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
  latent_image_ids = latent_image_ids.reshape(
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
  )
 
  return latent_image_ids.to(device=device, dtype=dtype)
@@ -394,6 +378,40 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
 
  return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
  def prepare_extra_step_kwargs(self, generator, eta):
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -441,39 +459,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  """
  self.vae.disable_tiling()
 
- def prepare_latents(
- self,
- batch_size,
- num_channels_latents,
- height,
- width,
- dtype,
- device,
- generator,
- latents=None,
- ):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
- shape = (batch_size, num_channels_latents, height, width)
-
- if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
-
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-
- return latents, latent_image_ids
-
  @property
  def guidance_scale(self):
  return self._guidance_scale
@@ -517,9 +502,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
  output_type: Optional[str] = "pil",
+ cfg: Optional[bool] = True,
  return_dict: bool = True,
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- max_sequence_length: int = 300,
+ max_sequence_length: int = 512,
  **kwargs,
  ):
  height = height or self.default_sample_size * self.vae_scale_factor
@@ -531,8 +517,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  prompt_2,
  height,
  width,
- negative_prompt=negative_prompt,
- negative_prompt_2=negative_prompt_2,
  prompt_embeds=prompt_embeds,
  negative_prompt_embeds=negative_prompt_embeds,
  pooled_prompt_embeds=pooled_prompt_embeds,
@@ -543,9 +527,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  self._guidance_scale = guidance_scale
  self._joint_attention_kwargs = joint_attention_kwargs
  self._interrupt = False
-
- do_classifier_free_guidance = guidance_scale > 1.0
-
+
  # 2. Define call parameters
  if prompt is not None and isinstance(prompt, str):
  batch_size = 1
@@ -565,6 +547,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  text_ids,
  negative_prompt_embeds,
  negative_pooled_prompt_embeds,
+ negative_text_ids,
  ) = self.encode_prompt(
  prompt=prompt,
  prompt_2=prompt_2,
@@ -583,7 +566,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
 
  if self.do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- pooled_prompt_embeds = torch.cat([negative_prompt_embeds, pooled_prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
  # 4. Prepare latent variables
  num_channels_latents = self.transformer.config.in_channels // 4
@@ -593,7 +576,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  height,
  width,
  prompt_embeds.dtype,
- negative_prompt_embeds.dtype,
  device,
  generator,
  latents,
@@ -602,7 +584,13 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  # 5. Prepare timesteps
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
  image_seq_len = latents.shape[1]
- mu = calculate_timestep_shift(image_seq_len)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
  timesteps, num_inference_steps = prepare_timesteps(
  self.scheduler,
  num_inference_steps,
@@ -611,6 +599,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  sigmas,
  mu=mu,
  )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
  self._num_timesteps = len(timesteps)
 
  # 6. Denoising loop
@@ -629,7 +618,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  else:
  guidance = None
 
- noise_pred = self.transformer(
+ noise_pred_text = self.transformer(
  hidden_states=latent_model_input,
  timestep=timestep / 1000,
  guidance=guidance,
@@ -640,11 +629,23 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
  joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
 
  if self.do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
+ noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+ else: noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
  # compute the previous noisy sample x_t -> x_t-1
  latents_dtype = latents.dtype
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
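
For orientation, a minimal usage sketch of the pipeline this commit modifies follows. It is an illustration only, not part of the commit: the checkpoint id, the import path, the `negative_prompt` argument, and the `.images` return field are assumptions inferred from the diff above and from the usual diffusers conventions.

# Hypothetical usage sketch for the FluxWithCFGPipeline defined in pipeline13.py.
# Assumptions (not confirmed by this commit): the base checkpoint id, that pipeline13.py
# is importable from the working directory, that __call__ accepts `negative_prompt`,
# and that the call returns an output object with an `.images` list.
import torch
from pipeline13 import FluxWithCFGPipeline

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

result = pipe(
    prompt="a watercolor lighthouse at dawn",
    negative_prompt="blurry, washed out",  # assumed to feed the negative (uncond) branch added above
    guidance_scale=3.5,                    # weights noise_pred_uncond vs. noise_pred_text in the CFG step
    num_inference_steps=8,
    height=1024,
    width=1024,
    max_sequence_length=512,               # new default introduced by this commit
    cfg=True,                              # flag added to __call__ in this commit
)
result.images[0].save("flux_cfg_sample.png")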