AlanB committed
Commit
4afe117
1 Parent(s): 26206b5

Reverted hack to get sequential_cpu_offload working. Not perfect.

Files changed (1)
  1. pipeline.py +6 -67
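
For context, a minimal usage sketch (not part of this commit; the model id and custom_pipeline reference are placeholders): with the local overrides removed below, enable_sequential_cpu_offload falls back to the stock implementation inherited from StableDiffusionPipeline.

import torch
from diffusers import DiffusionPipeline

# Hypothetical usage sketch: the model id and custom_pipeline name are placeholders,
# not taken from this commit. After the revert, enable_sequential_cpu_offload is the
# inherited diffusers implementation rather than the override deleted below.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # placeholder model id
    custom_pipeline="lpw_stable_diffusion",  # assumed community pipeline name
    torch_dtype=torch.float16,
)
pipe.enable_sequential_cpu_offload()  # submodule-level offload via accelerate
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")
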
pipeline.py CHANGED
@@ -16,7 +16,7 @@ from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
-from diffusers.utils import deprecate, is_accelerate_available, is_accelerate_version
+
 
 try:
     from diffusers.utils import PIL_INTERPOLATION
@@ -281,7 +281,6 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
-    unet_device = torch.device('cpu') if pipe.unet.device == torch.device('meta') else pipe.unet.device
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -330,7 +329,7 @@ def get_weighted_text_embeddings(
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
-    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=unet_device)
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
     if uncond_prompt is not None:
         uncond_tokens, uncond_weights = pad_tokens_and_weights(
             uncond_tokens,
@@ -341,7 +340,7 @@ def get_weighted_text_embeddings(
             no_boseos_middle=no_boseos_middle,
             chunk_length=pipe.tokenizer.model_max_length,
         )
-        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=unet_device)
+        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
@@ -350,8 +349,7 @@ def get_weighted_text_embeddings(
         pipe.tokenizer.model_max_length,
         no_boseos_middle=no_boseos_middle,
     )
-    text_embeddings = text_embeddings.to(device=unet_device)
-    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=unet_device)
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
             pipe,
@@ -359,8 +357,7 @@ def get_weighted_text_embeddings(
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
         )
-        uncond_embeddings = uncond_embeddings.to(device=unet_device)
-        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=unet_device)
+        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
@@ -484,59 +481,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-        Note that offloading happens on a submodule basis. Memory savings are higher than with
-        `enable_model_cpu_offload`, but performance is lower.
-        """
-        if is_accelerate_available():
-            from accelerate import cpu_offload
-        else:
-            raise ImportError("Please install accelerate via `pip install accelerate`")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            cpu_offload(cpu_offloaded_model, device)
-
-        if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
-
-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     @property
     def _execution_device(self):
         r"""
@@ -544,8 +488,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        #if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
-        if not hasattr(self.unet, "_hf_hook"):
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
             if (
@@ -915,10 +858,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        # 12. Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
-
         if not return_dict:
             return image, has_nsfw_concept
 
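
The hunks above restore the stock device handling: prompt tensors in get_weighted_text_embeddings are created on pipe.device again, and _execution_device only walks the unet's accelerate hooks when the pipeline itself reports the meta device. A minimal sketch of that resolution logic, assuming accelerate's sequential offload attaches an _hf_hook carrying an execution_device to the unet's submodules:

import torch

def resolve_execution_device(pipe):
    # Sketch of the restored _execution_device behaviour (assumption: accelerate's
    # cpu_offload attaches an _hf_hook with an execution_device attribute to the
    # unet's submodules once sequential offload is enabled).
    if pipe.device != torch.device("meta") or not hasattr(pipe.unet, "_hf_hook"):
        # No offload hooks installed: the pipeline's own device is authoritative.
        return pipe.device
    for module in pipe.unet.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            # The hook records the GPU the submodule is moved to for forward().
            return torch.device(module._hf_hook.execution_device)
    return pipe.device
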
 
 