skytnt committed
Commit 69d52d4 (parent: 92ddf35)

Update pipeline.py

Files changed (1)
pipeline.py (+233 -346)
pipeline.py CHANGED
@@ -459,14 +459,17 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         self.enable_attention_slicing(None)
 
     @torch.no_grad()
-    def text2img(
+    def __call__(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: int = 512,
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
@@ -484,6 +487,17 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -497,9 +511,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -542,6 +559,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -586,35 +606,81 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             **kwargs
         )
 
-        # get the initial random noise unless the user supplied it
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
 
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
-        if latents is None:
+        init_latents_orig = None
+        mask = None
+        noise = None
+
+        if init_image is None:
+            # get the initial random noise unless the user supplied it
+
+            # Unlike in other pipelines, latents need to be generated in the target device
+            # for 1-to-1 results reproducibility with the CompVis implementation.
+            # However this currently doesn't work in `mps`.
+            latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+
+            if latents is None:
+                if self.device.type == "mps":
+                    # randn does not exist on mps
+                    latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                        self.device
+                    )
+                else:
+                    latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+            else:
+                if latents.shape != latents_shape:
+                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+                latents = latents.to(self.device)
+
+            timesteps = self.scheduler.timesteps.to(self.device)
+
+            # scale the initial noise by the standard deviation required by the scheduler
+            latents = latents * self.scheduler.init_noise_sigma
+        else:
+            if isinstance(init_image, PIL.Image.Image):
+                init_image = preprocess_image(init_image)
+            # encode the init image into latents and scale the latents
+            init_image = init_image.to(device=self.device, dtype=latents_dtype)
+            init_latent_dist = self.vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents_orig = init_latents
+
+            # preprocess mask
+            if mask_image is not None:
+                if isinstance(mask_image, PIL.Image.Image):
+                    mask_image = preprocess_mask(mask_image)
+                mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+                mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+                # check sizes
+                if not mask.shape == init_latents.shape:
+                    raise ValueError("The mask and init_image should be the same size!")
+
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+            # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
-        else:
-            if latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -625,7 +691,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -641,6 +707,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
@@ -671,15 +742,106 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
-    @torch.no_grad()
-    def img2img(
+    def text2img(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
+    ):
+        r"""
+        Function for text-to-image generation.
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            latents=latents,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            callback_steps=callback_steps,
+            **kwargs
+        )
+
+    def img2img(
+        self,
         init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
@@ -691,14 +853,16 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         **kwargs,
     ):
         r"""
-        Function invoked when calling the pipeline for generation.
-
+        Function for image-to-image generation.
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
             init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
                 `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
@@ -714,9 +878,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -739,7 +900,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -747,169 +907,33 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = preprocess_image(init_image)
-
-        # get prompt text embeddings
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-        if do_classifier_free_guidance:
-            if type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
-            else:
-                uncond_tokens = negative_prompt
-
-        text_embeddings = get_weighted_text_embeddings(
-            pipe=self,
+        return self.__call__(
             prompt=prompt,
-            uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
+            negative_prompt=negative_prompt,
+            init_image=init_image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
             max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            callback_steps=callback_steps,
             **kwargs
         )
 
-        # encode the init image into latents and scale the latents
-        latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        if isinstance(prompt, str):
-            prompt = [prompt]
-        if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
-            # expand init_latents for batch_size
-            deprecation_message = (
-                f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
-                " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
-                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
-                " your script to pass as many init images as text prompts to suppress this warning."
-            )
-            deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
-            additional_image_per_prompt = len(prompt) // init_latents.shape[0]
-            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
-        elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
-            )
-        else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
-
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-            )
-        else:
-            has_nsfw_concept = None
-
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
-        if not return_dict:
-            return (image, has_nsfw_concept)
-
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-
-    @torch.no_grad()
     def inpaint(
         self,
-        prompt: Union[str, List[str]],
         init_image: Union[torch.FloatTensor, PIL.Image.Image],
         mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
@@ -921,11 +945,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         **kwargs,
     ):
         r"""
-        Function invoked when calling the pipeline for generation.
-
+        Function for inpaint.
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
             init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process. This is the image whose masked region will be inpainted.
@@ -934,6 +955,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -948,9 +974,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -973,7 +996,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -981,156 +1003,21 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # get prompt text embeddings
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-        if do_classifier_free_guidance:
-            if type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-        text_embeddings = get_weighted_text_embeddings(
-            pipe=self,
+        return self.__call__(
             prompt=prompt,
-            uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
+            negative_prompt=negative_prompt,
+            init_image=init_image,
+            mask_image=mask_image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
             max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            callback_steps=callback_steps,
             **kwargs
         )
-
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-
-        # encode the init image into latents and scale the latents
-        latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
-
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
-            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
-        else:
-            has_nsfw_concept = None
-
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
-        if not return_dict:
-            return (image, has_nsfw_concept)
-
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
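
After this refactor, `text2img`, `img2img`, and `inpaint` are thin wrappers that delegate to the decorated `__call__`, which picks the generation mode from its inputs: with no `init_image` it samples fresh latents (text-to-image), with `init_image` it encodes the image and adds noise scaled by `strength` (img2img), and with `init_image` plus `mask_image` it additionally blends the noised original latents with the denoised latents through the mask at every step (inpainting). A minimal usage sketch, not part of the commit; the model id, file names, and the `custom_pipeline` loading shown here are illustrative assumptions:

    import torch
    from diffusers import DiffusionPipeline
    from PIL import Image

    # Assumption: this pipeline.py is loaded as a custom/community pipeline;
    # the model id below is illustrative.
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="lpw_stable_diffusion",  # assumed pipeline name
        torch_dtype=torch.float16,
    ).to("cuda")

    # text2img: init_image is None, so __call__ samples fresh latents.
    image = pipe.text2img("a fantasy landscape, highly detailed", width=512, height=512).images[0]

    # img2img: passing init_image makes __call__ encode it and add noise scaled by `strength`.
    init = Image.open("init.png").convert("RGB").resize((512, 512))  # illustrative file
    image = pipe.img2img(init_image=init, prompt="the same landscape in winter", strength=0.6).images[0]

    # inpaint: mask_image additionally enables the per-step latent masking added in this commit.
    mask = Image.open("mask.png").convert("L").resize((512, 512))  # illustrative file
    image = pipe.inpaint(init_image=init, mask_image=mask, prompt="a frozen lake", strength=0.75).images[0]

    image.save("out.png")

Because all three wrappers funnel into one denoising loop, the guidance, scheduler, safety-checker, and callback logic no longer needs to be maintained in three copies, which is where the net -113 lines in this commit come from.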