Reverted the hack that was added to get sequential_cpu_offload working. Not perfect.
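For context, with the hack reverted, sequential CPU offload comes from the base StableDiffusionPipeline class and is enabled the usual diffusers way. The snippet below is a minimal sketch, not code from this repo: the checkpoint id and prompt are placeholders, the custom_pipeline name is assumed, and whether offload fully works with the long-prompt-weighting code is exactly what "Not perfect" flags.

# Minimal usage sketch (assumes diffusers with accelerate installed and a CUDA device).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # placeholder checkpoint id
    custom_pipeline="lpw_stable_diffusion",  # assumed name for this community pipeline
    torch_dtype=torch.float16,
)

# Inherited from StableDiffusionPipeline: state dicts stay on CPU and each submodule
# is streamed to the GPU only for its own forward call, then offloaded again.
pipe.enable_sequential_cpu_offload()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("out.png")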
pipeline.py  +6 -67
pipeline.py
CHANGED
@@ -16,7 +16,7 @@ from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
-
+
 
 try:
     from diffusers.utils import PIL_INTERPOLATION
@@ -281,7 +281,6 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
-    unet_device = torch.device('cpu') if pipe.unet.device == torch.device('meta') else pipe.unet.device
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -330,7 +329,7 @@ def get_weighted_text_embeddings(
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
-    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=unet_device)
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
     if uncond_prompt is not None:
         uncond_tokens, uncond_weights = pad_tokens_and_weights(
             uncond_tokens,
@@ -341,7 +340,7 @@ def get_weighted_text_embeddings(
             no_boseos_middle=no_boseos_middle,
             chunk_length=pipe.tokenizer.model_max_length,
         )
-        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=unet_device)
+        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
@@ -350,8 +349,7 @@ def get_weighted_text_embeddings(
         pipe.tokenizer.model_max_length,
         no_boseos_middle=no_boseos_middle,
     )
-
-    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=unet_device)
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
             pipe,
@@ -359,8 +357,7 @@ def get_weighted_text_embeddings(
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
         )
-
-        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=unet_device)
+        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
@@ -484,59 +481,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-        Note that offloading happens on a submodule basis. Memory savings are higher than with
-        `enable_model_cpu_offload`, but performance is lower.
-        """
-        if is_accelerate_available():
-            from accelerate import cpu_offload
-        else:
-            raise ImportError("Please install accelerate via `pip install accelerate`")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            cpu_offload(cpu_offloaded_model, device)
-
-        if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
-
-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     @property
     def _execution_device(self):
         r"""
@@ -544,8 +488,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
        """
-
-        if not hasattr(self.unet, "_hf_hook"):
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
@@ -915,10 +858,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        # 12. Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
-
         if not return_dict:
             return image, has_nsfw_concept
 
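The removed lines in get_weighted_text_embeddings were the core of the hack: they created the token and weight tensors on a unet-derived device instead of pipe.device, because under enable_sequential_cpu_offload the modules sit on torch.device("meta") between forward calls and pipe.device can therefore report "meta". Reconstructed as a standalone helper purely for illustration (the name _hack_tensor_device is hypothetical and does not appear in the file):

import torch

def _hack_tensor_device(pipe):
    # Mirrors the removed line: fall back to CPU while the unet is offloaded to the
    # meta device, otherwise create tensors next to the unet.
    if pipe.unet.device == torch.device("meta"):
        return torch.device("cpu")
    return pipe.unet.device

After the revert those tensors are created on pipe.device again, so the only meta-device handling left in this file is the _execution_device property shown in the diff above.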