Update pipeline.py
Browse files- pipeline.py +120 -53
pipeline.py
CHANGED
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
|
|
8 |
import PIL
|
9 |
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
10 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
11 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
12 |
-
from diffusers.utils import
|
|
|
13 |
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
17 |
|
18 |
re_attention = re.compile(
|
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
404 |
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
405 |
"""
|
406 |
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
|
429 |
def _encode_prompt(
|
430 |
self,
|
@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
752 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
753 |
|
754 |
# 8. Denoising loop
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
if callback is not None:
|
783 |
-
callback(i, t, latents)
|
784 |
-
if is_cancelled_callback is not None and is_cancelled_callback():
|
785 |
-
return None
|
786 |
|
787 |
# 9. Post-processing
|
788 |
image = self.decode_latents(latents)
|
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
+
import diffusers
|
9 |
import PIL
|
10 |
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
11 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
12 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
13 |
+
from diffusers.utils import deprecate, logging
|
14 |
+
from packaging import version
|
15 |
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
16 |
|
17 |
|
18 |
+
try:
|
19 |
+
from diffusers.utils import PIL_INTERPOLATION
|
20 |
+
except ImportError:
|
21 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
22 |
+
PIL_INTERPOLATION = {
|
23 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
24 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
25 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
26 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
27 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
28 |
+
}
|
29 |
+
else:
|
30 |
+
PIL_INTERPOLATION = {
|
31 |
+
"linear": PIL.Image.LINEAR,
|
32 |
+
"bilinear": PIL.Image.BILINEAR,
|
33 |
+
"bicubic": PIL.Image.BICUBIC,
|
34 |
+
"lanczos": PIL.Image.LANCZOS,
|
35 |
+
"nearest": PIL.Image.NEAREST,
|
36 |
+
}
|
37 |
+
# ------------------------------------------------------------------------------
|
38 |
+
|
39 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
|
41 |
re_attention = re.compile(
|
|
|
427 |
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
428 |
"""
|
429 |
|
430 |
+
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
431 |
+
|
432 |
+
def __init__(
|
433 |
+
self,
|
434 |
+
vae: AutoencoderKL,
|
435 |
+
text_encoder: CLIPTextModel,
|
436 |
+
tokenizer: CLIPTokenizer,
|
437 |
+
unet: UNet2DConditionModel,
|
438 |
+
scheduler: SchedulerMixin,
|
439 |
+
safety_checker: StableDiffusionSafetyChecker,
|
440 |
+
feature_extractor: CLIPFeatureExtractor,
|
441 |
+
requires_safety_checker: bool = True,
|
442 |
+
):
|
443 |
+
super().__init__(
|
444 |
+
vae=vae,
|
445 |
+
text_encoder=text_encoder,
|
446 |
+
tokenizer=tokenizer,
|
447 |
+
unet=unet,
|
448 |
+
scheduler=scheduler,
|
449 |
+
safety_checker=safety_checker,
|
450 |
+
feature_extractor=feature_extractor,
|
451 |
+
requires_safety_checker=requires_safety_checker,
|
452 |
+
)
|
453 |
+
self.__init__additional__()
|
454 |
+
|
455 |
+
else:
|
456 |
+
|
457 |
+
def __init__(
|
458 |
+
self,
|
459 |
+
vae: AutoencoderKL,
|
460 |
+
text_encoder: CLIPTextModel,
|
461 |
+
tokenizer: CLIPTokenizer,
|
462 |
+
unet: UNet2DConditionModel,
|
463 |
+
scheduler: SchedulerMixin,
|
464 |
+
safety_checker: StableDiffusionSafetyChecker,
|
465 |
+
feature_extractor: CLIPFeatureExtractor,
|
466 |
+
):
|
467 |
+
super().__init__(
|
468 |
+
vae=vae,
|
469 |
+
text_encoder=text_encoder,
|
470 |
+
tokenizer=tokenizer,
|
471 |
+
unet=unet,
|
472 |
+
scheduler=scheduler,
|
473 |
+
safety_checker=safety_checker,
|
474 |
+
feature_extractor=feature_extractor,
|
475 |
+
)
|
476 |
+
self.__init__additional__()
|
477 |
+
|
478 |
+
def __init__additional__(self):
|
479 |
+
if not hasattr(self, "vae_scale_factor"):
|
480 |
+
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
481 |
+
|
482 |
+
@property
|
483 |
+
def _execution_device(self):
|
484 |
+
r"""
|
485 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
486 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
487 |
+
hooks.
|
488 |
+
"""
|
489 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
490 |
+
return self.device
|
491 |
+
for module in self.unet.modules():
|
492 |
+
if (
|
493 |
+
hasattr(module, "_hf_hook")
|
494 |
+
and hasattr(module._hf_hook, "execution_device")
|
495 |
+
and module._hf_hook.execution_device is not None
|
496 |
+
):
|
497 |
+
return torch.device(module._hf_hook.execution_device)
|
498 |
+
return self.device
|
499 |
|
500 |
def _encode_prompt(
|
501 |
self,
|
|
|
823 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
824 |
|
825 |
# 8. Denoising loop
|
826 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
827 |
+
# expand the latents if we are doing classifier free guidance
|
828 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
829 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
830 |
+
|
831 |
+
# predict the noise residual
|
832 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
833 |
+
|
834 |
+
# perform guidance
|
835 |
+
if do_classifier_free_guidance:
|
836 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
837 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
838 |
+
|
839 |
+
# compute the previous noisy sample x_t -> x_t-1
|
840 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
841 |
+
|
842 |
+
if mask is not None:
|
843 |
+
# masking
|
844 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
845 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
846 |
+
|
847 |
+
# call the callback, if provided
|
848 |
+
if i % callback_steps == 0:
|
849 |
+
if callback is not None:
|
850 |
+
callback(i, t, latents)
|
851 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
852 |
+
return None
|
|
|
|
|
|
|
|
|
853 |
|
854 |
# 9. Post-processing
|
855 |
image = self.decode_latents(latents)
|