zhiweili committed
Commit e581b1e
1 Parent(s): f01a424
app_haircolor_pix2pix.py CHANGED
@@ -10,8 +10,10 @@ from segment_utils import(
     restore_result,
 )
 from diffusers import (
+    DiffusionPipeline,
     StableDiffusionInstructPix2PixPipeline,
     EulerAncestralDiscreteScheduler,
+    T2IAdapter,
 )
 
 from controlnet_aux import (
@@ -30,10 +32,18 @@ DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res,
 
 DEFAULT_CATEGORY = "hair"
 
-basepipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+adapter = T2IAdapter.from_pretrained(
+    "TencentARC/t2iadapter_canny_sd15v2",
+    torch_dtype=torch.float16,
+    varient="fp16",
+)
+
+basepipeline = DiffusionPipeline.from_pretrained(
     BASE_MODEL,
     torch_dtype=torch.float16,
     use_safetensors=True,
+    adapter=adapter,
+    custom_pipeline="./pipelines/pipeline_sd_adapter_p2p.py",
 )
 
 basepipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(basepipeline.scheduler.config)
@@ -51,10 +61,15 @@ def image_to_image(
     guidance_scale: float,
     image_guidance_scale: float,
     generate_size: int,
+    cond_scale1: float = 1.2,
 ):
     run_task_time = 0
     time_cost_str = ''
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    canny_image = CannyDetector()(input_image, 384, generate_size)
+
+    cond_image = canny_image
+    cond_scale = cond_scale1
 
     generator = torch.Generator(device=DEVICE).manual_seed(seed)
     generated_image = basepipeline(
@@ -65,6 +80,8 @@ def image_to_image(
         guidance_scale=guidance_scale,
         image_guidance_scale=image_guidance_scale,
         num_inference_steps=num_steps,
+        adapter_image=cond_image,
+        adapter_conditioning_scale=cond_scale,
     ).images[0]
 
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -110,7 +127,6 @@ def create_demo() -> gr.Blocks:
             seed = gr.Number(label="Seed", value=8)
             category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
             cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale1")
-            cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale2")
             g_btn = gr.Button("Edit Image")
 
             with gr.Row():
@@ -129,7 +145,7 @@ def create_demo() -> gr.Blocks:
             outputs=[origin_area_image, croper],
         ).success(
             fn=image_to_image,
-            inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, image_guidance_scale, generate_size],
+            inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, image_guidance_scale, generate_size, cond_scale1],
             outputs=[generated_image, generated_cost],
         ).success(
             fn=restore_result,
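In short, this commit swaps the stock `StableDiffusionInstructPix2PixPipeline` for a local custom pipeline that additionally takes a T2I-Adapter Canny condition, and `image_to_image` now builds a Canny map and forwards it as `adapter_image` with `adapter_conditioning_scale`. Below is a minimal sketch of the resulting call path outside the Gradio app; it assumes a CUDA device and uses `timbrooks/instruct-pix2pix` as a stand-in for the app's `BASE_MODEL`, which is defined outside this diff.

```python
# Sketch only: mirrors how app_haircolor_pix2pix.py wires the adapter after this commit.
# "timbrooks/instruct-pix2pix" and "face.jpg" are placeholder assumptions.
import torch
from PIL import Image
from controlnet_aux import CannyDetector
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, T2IAdapter

# Canny T2I-Adapter for SD 1.5, as in the diff
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2iadapter_canny_sd15v2",
    torch_dtype=torch.float16,
)

# Load the local custom pipeline and hand it the adapter component
pipe = DiffusionPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float16,
    adapter=adapter,
    custom_pipeline="./pipelines/pipeline_sd_adapter_p2p.py",
).to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

image = Image.open("face.jpg").convert("RGB").resize((512, 512))
# Same call shape as the app: CannyDetector()(input_image, 384, generate_size)
canny_image = CannyDetector()(image, 384, 512)

result = pipe(
    prompt="change the hair color to pink",
    image=image,
    adapter_image=canny_image,
    num_inference_steps=20,
    guidance_scale=7.5,
    image_guidance_scale=1.5,
    adapter_conditioning_scale=1.2,
    generator=torch.Generator(device="cuda").manual_seed(8),
).images[0]
result.save("result.jpg")
```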
pipelines/pipeline_sd_adapter_p2p.py ADDED
@@ -0,0 +1,1034 @@
1
+ # Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
22
+
23
+ from diffusers.callbacks import (
24
+ MultiPipelineCallbacks,
25
+ PipelineCallback,
26
+ )
27
+
28
+ from diffusers.image_processor import (
29
+ PipelineImageInput,
30
+ VaeImageProcessor,
31
+ )
32
+
33
+ from diffusers.loaders import (
34
+ IPAdapterMixin,
35
+ StableDiffusionLoraLoaderMixin,
36
+ TextualInversionLoaderMixin,
37
+ )
38
+
39
+ from diffusers.models import (
40
+ AutoencoderKL,
41
+ ImageProjection,
42
+ MultiAdapter,
43
+ T2IAdapter,
44
+ UNet2DConditionModel,
45
+ )
46
+
47
+ from diffusers.schedulers import (
48
+ KarrasDiffusionSchedulers,
49
+ )
50
+
51
+ from diffusers.utils import (
52
+ PIL_INTERPOLATION,
53
+ deprecate,
54
+ logging,
55
+ )
56
+
57
+ from diffusers.utils.torch_utils import (
58
+ randn_tensor,
59
+ )
60
+
61
+ from diffusers.pipelines.pipeline_utils import (
62
+ DiffusionPipeline,
63
+ StableDiffusionMixin,
64
+ )
65
+
66
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
67
+ StableDiffusionPipelineOutput,
68
+ )
69
+
70
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
71
+ StableDiffusionSafetyChecker,
72
+ )
73
+
74
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
75
+
76
+ def _preprocess_adapter_image(image, height, width):
77
+ if isinstance(image, torch.Tensor):
78
+ return image
79
+ elif isinstance(image, PIL.Image.Image):
80
+ image = [image]
81
+
82
+ if isinstance(image[0], PIL.Image.Image):
83
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
84
+ image = [
85
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
86
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
87
+ image = np.concatenate(image, axis=0)
88
+ image = np.array(image).astype(np.float32) / 255.0
89
+ image = image.transpose(0, 3, 1, 2)
90
+ image = torch.from_numpy(image)
91
+ elif isinstance(image[0], torch.Tensor):
92
+ if image[0].ndim == 3:
93
+ image = torch.stack(image, dim=0)
94
+ elif image[0].ndim == 4:
95
+ image = torch.cat(image, dim=0)
96
+ else:
97
+ raise ValueError(
98
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
99
+ )
100
+ return image
101
+
102
+
103
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
104
+ def preprocess(image):
105
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
106
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
107
+ if isinstance(image, torch.Tensor):
108
+ return image
109
+ elif isinstance(image, PIL.Image.Image):
110
+ image = [image]
111
+
112
+ if isinstance(image[0], PIL.Image.Image):
113
+ w, h = image[0].size
114
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
115
+
116
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
117
+ image = np.concatenate(image, axis=0)
118
+ image = np.array(image).astype(np.float32) / 255.0
119
+ image = image.transpose(0, 3, 1, 2)
120
+ image = 2.0 * image - 1.0
121
+ image = torch.from_numpy(image)
122
+ elif isinstance(image[0], torch.Tensor):
123
+ image = torch.cat(image, dim=0)
124
+ return image
125
+
126
+
127
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
128
+ def retrieve_latents(
129
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
130
+ ):
131
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
132
+ return encoder_output.latent_dist.sample(generator)
133
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
134
+ return encoder_output.latent_dist.mode()
135
+ elif hasattr(encoder_output, "latents"):
136
+ return encoder_output.latents
137
+ else:
138
+ raise AttributeError("Could not access latents of provided encoder_output")
139
+
140
+
141
+ class StableDiffusionInstructPix2PixPipeline(
142
+ DiffusionPipeline,
143
+ StableDiffusionMixin,
144
+ TextualInversionLoaderMixin,
145
+ StableDiffusionLoraLoaderMixin,
146
+ IPAdapterMixin,
147
+ ):
148
+ r"""
149
+ Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
150
+
151
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
152
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
153
+
154
+ The pipeline also inherits the following loading methods:
155
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
156
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
157
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
158
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
159
+
160
+ Args:
161
+ vae ([`AutoencoderKL`]):
162
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
163
+ text_encoder ([`~transformers.CLIPTextModel`]):
164
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
165
+ tokenizer ([`~transformers.CLIPTokenizer`]):
166
+ A `CLIPTokenizer` to tokenize text.
167
+ unet ([`UNet2DConditionModel`]):
168
+ A `UNet2DConditionModel` to denoise the encoded image latents.
169
+ scheduler ([`SchedulerMixin`]):
170
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
171
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
172
+ safety_checker ([`StableDiffusionSafetyChecker`]):
173
+ Classification module that estimates whether generated images could be considered offensive or harmful.
174
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
175
+ about a model's potential harms.
176
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
177
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
178
+ """
179
+
180
+ model_cpu_offload_seq = "text_encoder->unet->vae"
181
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
182
+ _exclude_from_cpu_offload = ["safety_checker"]
183
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]
184
+
185
+ def __init__(
186
+ self,
187
+ vae: AutoencoderKL,
188
+ text_encoder: CLIPTextModel,
189
+ tokenizer: CLIPTokenizer,
190
+ unet: UNet2DConditionModel,
191
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
192
+ scheduler: KarrasDiffusionSchedulers,
193
+ safety_checker: StableDiffusionSafetyChecker,
194
+ feature_extractor: CLIPImageProcessor,
195
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
196
+ requires_safety_checker: bool = True,
197
+ ):
198
+ super().__init__()
199
+
200
+ if safety_checker is None and requires_safety_checker:
201
+ logger.warning(
202
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
203
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
204
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
205
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
206
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
207
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
208
+ )
209
+
210
+ if safety_checker is not None and feature_extractor is None:
211
+ raise ValueError(
212
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
213
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
214
+ )
215
+
216
+ self.register_modules(
217
+ vae=vae,
218
+ text_encoder=text_encoder,
219
+ tokenizer=tokenizer,
220
+ unet=unet,
221
+ adapter=adapter,
222
+ scheduler=scheduler,
223
+ safety_checker=safety_checker,
224
+ feature_extractor=feature_extractor,
225
+ image_encoder=image_encoder,
226
+ )
227
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
228
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
229
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
230
+
231
+ @torch.no_grad()
232
+ def __call__(
233
+ self,
234
+ prompt: Union[str, List[str]] = None,
235
+ image: PipelineImageInput = None,
236
+ height: Optional[int] = None,
237
+ width: Optional[int] = None,
238
+ adapter_image: PipelineImageInput = None,
239
+ num_inference_steps: int = 100,
240
+ guidance_scale: float = 7.5,
241
+ image_guidance_scale: float = 1.5,
242
+ negative_prompt: Optional[Union[str, List[str]]] = None,
243
+ num_images_per_prompt: Optional[int] = 1,
244
+ eta: float = 0.0,
245
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
246
+ latents: Optional[torch.Tensor] = None,
247
+ prompt_embeds: Optional[torch.Tensor] = None,
248
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
249
+ ip_adapter_image: Optional[PipelineImageInput] = None,
250
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
251
+ output_type: Optional[str] = "pil",
252
+ return_dict: bool = True,
253
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
254
+ callback_on_step_end: Optional[
255
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
256
+ ] = None,
257
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
258
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
259
+ **kwargs,
260
+ ):
261
+ r"""
262
+ The call function to the pipeline for generation.
263
+
264
+ Args:
265
+ prompt (`str` or `List[str]`, *optional*):
266
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
267
+ image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
268
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
269
+ image latents as `image`, but if passing latents directly it is not encoded again.
270
+ num_inference_steps (`int`, *optional*, defaults to 100):
271
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
272
+ expense of slower inference.
273
+ guidance_scale (`float`, *optional*, defaults to 7.5):
274
+ A higher guidance scale value encourages the model to generate images closely linked to the text
275
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
276
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
277
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
278
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
279
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
280
+ value of at least `1`.
281
+ negative_prompt (`str` or `List[str]`, *optional*):
282
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
283
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
284
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
285
+ The number of images to generate per prompt.
286
+ eta (`float`, *optional*, defaults to 0.0):
287
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
288
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
289
+ generator (`torch.Generator`, *optional*):
290
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
291
+ generation deterministic.
292
+ latents (`torch.Tensor`, *optional*):
293
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
294
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
295
+ tensor is generated by sampling using the supplied random `generator`.
296
+ prompt_embeds (`torch.Tensor`, *optional*):
297
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
298
+ provided, text embeddings are generated from the `prompt` input argument.
299
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
300
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
301
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
302
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
303
+ Optional image input to work with IP Adapters.
304
+ output_type (`str`, *optional*, defaults to `"pil"`):
305
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
306
+ return_dict (`bool`, *optional*, defaults to `True`):
307
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
308
+ plain tuple.
309
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
310
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
311
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
312
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
313
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
314
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
315
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
316
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
317
+ `._callback_tensor_inputs` attribute of your pipeline class.
318
+ cross_attention_kwargs (`dict`, *optional*):
319
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
320
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
321
+
322
+ Examples:
323
+
324
+ ```py
325
+ >>> import PIL
326
+ >>> import requests
327
+ >>> import torch
328
+ >>> from io import BytesIO
329
+
330
+ >>> from diffusers import StableDiffusionInstructPix2PixPipeline
331
+
332
+
333
+ >>> def download_image(url):
334
+ ... response = requests.get(url)
335
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
336
+
337
+
338
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
339
+
340
+ >>> image = download_image(img_url).resize((512, 512))
341
+
342
+ >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
343
+ ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
344
+ ... )
345
+ >>> pipe = pipe.to("cuda")
346
+
347
+ >>> prompt = "make the mountains snowy"
348
+ >>> image = pipe(prompt=prompt, image=image).images[0]
349
+ ```
350
+
351
+ Returns:
352
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
353
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
354
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
355
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
356
+ "not-safe-for-work" (nsfw) content.
357
+ """
358
+ height, width = self._default_height_width(height, width, adapter_image)
359
+ device = self._execution_device
360
+
361
+ if isinstance(self.adapter, MultiAdapter):
362
+ adapter_input = []
363
+
364
+ for one_image in adapter_image:
365
+ one_image = _preprocess_adapter_image(one_image, height, width)
366
+ one_image = one_image.to(device=device, dtype=self.adapter.dtype)
367
+ adapter_input.append(one_image)
368
+ else:
369
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
370
+ adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
371
+
372
+ callback = kwargs.pop("callback", None)
373
+ callback_steps = kwargs.pop("callback_steps", None)
374
+
375
+ if callback is not None:
376
+ deprecate(
377
+ "callback",
378
+ "1.0.0",
379
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
380
+ )
381
+ if callback_steps is not None:
382
+ deprecate(
383
+ "callback_steps",
384
+ "1.0.0",
385
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
386
+ )
387
+
388
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
389
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
390
+
391
+ # 0. Check inputs
392
+ self.check_inputs(
393
+ prompt,
394
+ callback_steps,
395
+ negative_prompt,
396
+ prompt_embeds,
397
+ negative_prompt_embeds,
398
+ ip_adapter_image,
399
+ ip_adapter_image_embeds,
400
+ callback_on_step_end_tensor_inputs,
401
+ )
402
+ self._guidance_scale = guidance_scale
403
+ self._image_guidance_scale = image_guidance_scale
404
+
405
+ device = self._execution_device
406
+
407
+ if image is None:
408
+ raise ValueError("`image` input cannot be undefined.")
409
+
410
+ # 1. Define call parameters
411
+ if prompt is not None and isinstance(prompt, str):
412
+ batch_size = 1
413
+ elif prompt is not None and isinstance(prompt, list):
414
+ batch_size = len(prompt)
415
+ else:
416
+ batch_size = prompt_embeds.shape[0]
417
+
418
+ device = self._execution_device
419
+
420
+ # 2. Encode input prompt
421
+ prompt_embeds = self._encode_prompt(
422
+ prompt,
423
+ device,
424
+ num_images_per_prompt,
425
+ self.do_classifier_free_guidance,
426
+ negative_prompt,
427
+ prompt_embeds=prompt_embeds,
428
+ negative_prompt_embeds=negative_prompt_embeds,
429
+ )
430
+
431
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
432
+ image_embeds = self.prepare_ip_adapter_image_embeds(
433
+ ip_adapter_image,
434
+ ip_adapter_image_embeds,
435
+ device,
436
+ batch_size * num_images_per_prompt,
437
+ self.do_classifier_free_guidance,
438
+ )
439
+ # 3. Preprocess image
440
+ image = self.image_processor.preprocess(image)
441
+
442
+ # 4. set timesteps
443
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
444
+ timesteps = self.scheduler.timesteps
445
+
446
+ # 5. Prepare Image latents
447
+ image_latents = self.prepare_image_latents(
448
+ image,
449
+ batch_size,
450
+ num_images_per_prompt,
451
+ prompt_embeds.dtype,
452
+ device,
453
+ self.do_classifier_free_guidance,
454
+ )
455
+
456
+ height, width = image_latents.shape[-2:]
457
+ height = height * self.vae_scale_factor
458
+ width = width * self.vae_scale_factor
459
+
460
+ # 6. Prepare latent variables
461
+ num_channels_latents = self.vae.config.latent_channels
462
+ latents = self.prepare_latents(
463
+ batch_size * num_images_per_prompt,
464
+ num_channels_latents,
465
+ height,
466
+ width,
467
+ prompt_embeds.dtype,
468
+ device,
469
+ generator,
470
+ latents,
471
+ )
472
+
473
+ # 7. Check that shapes of latents and image match the UNet channels
474
+ num_channels_image = image_latents.shape[1]
475
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
476
+ raise ValueError(
477
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
478
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
479
+ f" `num_channels_image`: {num_channels_image} "
480
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
481
+ " `pipeline.unet` or your `image` input."
482
+ )
483
+
484
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
485
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
486
+
487
+ # 8.1 Add image embeds for IP-Adapter
488
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
489
+
490
+ # 9. Denoising loop
491
+ if isinstance(self.adapter, MultiAdapter):
492
+ adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
493
+ for k, v in enumerate(adapter_state):
494
+ adapter_state[k] = v
495
+ else:
496
+ adapter_state = self.adapter(adapter_input)
497
+ for k, v in enumerate(adapter_state):
498
+ adapter_state[k] = v * adapter_conditioning_scale
499
+ if num_images_per_prompt > 1:
500
+ for k, v in enumerate(adapter_state):
501
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
502
+ if self.do_classifier_free_guidance:
503
+ for k, v in enumerate(adapter_state):
504
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
505
+
506
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
507
+ self._num_timesteps = len(timesteps)
508
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
509
+ for i, t in enumerate(timesteps):
510
+ # Expand the latents if we are doing classifier free guidance.
511
+ # The latents are expanded 3 times because for pix2pix the guidance\
512
+ # is applied for both the text and the input image.
513
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
514
+
515
+ # concat latents, image_latents in the channel dimension
516
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
517
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
518
+
519
+ # predict the noise residual
520
+ noise_pred = self.unet(
521
+ scaled_latent_model_input,
522
+ t,
523
+ encoder_hidden_states=prompt_embeds,
524
+ added_cond_kwargs=added_cond_kwargs,
525
+ down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
526
+ cross_attention_kwargs=cross_attention_kwargs,
527
+ return_dict=False,
528
+ )[0]
529
+
530
+ # perform guidance
531
+ if self.do_classifier_free_guidance:
532
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
533
+ noise_pred = (
534
+ noise_pred_uncond
535
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
536
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
537
+ )
538
+
539
+ # compute the previous noisy sample x_t -> x_t-1
540
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
541
+
542
+ if callback_on_step_end is not None:
543
+ callback_kwargs = {}
544
+ for k in callback_on_step_end_tensor_inputs:
545
+ callback_kwargs[k] = locals()[k]
546
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
547
+
548
+ latents = callback_outputs.pop("latents", latents)
549
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
550
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
551
+ image_latents = callback_outputs.pop("image_latents", image_latents)
552
+
553
+ # call the callback, if provided
554
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
555
+ progress_bar.update()
556
+ if callback is not None and i % callback_steps == 0:
557
+ step_idx = i // getattr(self.scheduler, "order", 1)
558
+ callback(step_idx, t, latents)
559
+
560
+ if not output_type == "latent":
561
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
562
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
563
+ else:
564
+ image = latents
565
+ has_nsfw_concept = None
566
+
567
+ if has_nsfw_concept is None:
568
+ do_denormalize = [True] * image.shape[0]
569
+ else:
570
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
571
+
572
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
573
+
574
+ # Offload all models
575
+ self.maybe_free_model_hooks()
576
+
577
+ if not return_dict:
578
+ return (image, has_nsfw_concept)
579
+
580
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
581
+
582
+ def _encode_prompt(
583
+ self,
584
+ prompt,
585
+ device,
586
+ num_images_per_prompt,
587
+ do_classifier_free_guidance,
588
+ negative_prompt=None,
589
+ prompt_embeds: Optional[torch.Tensor] = None,
590
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
591
+ ):
592
+ r"""
593
+ Encodes the prompt into text encoder hidden states.
594
+
595
+ Args:
596
+ prompt (`str` or `List[str]`, *optional*):
597
+ prompt to be encoded
598
+ device: (`torch.device`):
599
+ torch device
600
+ num_images_per_prompt (`int`):
601
+ number of images that should be generated per prompt
602
+ do_classifier_free_guidance (`bool`):
603
+ whether to use classifier free guidance or not
604
+ negative_ prompt (`str` or `List[str]`, *optional*):
605
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
606
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
607
+ less than `1`).
608
+ prompt_embeds (`torch.Tensor`, *optional*):
609
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
610
+ provided, text embeddings will be generated from `prompt` input argument.
611
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
612
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
613
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
614
+ argument.
615
+ """
616
+ if prompt is not None and isinstance(prompt, str):
617
+ batch_size = 1
618
+ elif prompt is not None and isinstance(prompt, list):
619
+ batch_size = len(prompt)
620
+ else:
621
+ batch_size = prompt_embeds.shape[0]
622
+
623
+ if prompt_embeds is None:
624
+ # textual inversion: process multi-vector tokens if necessary
625
+ if isinstance(self, TextualInversionLoaderMixin):
626
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
627
+
628
+ text_inputs = self.tokenizer(
629
+ prompt,
630
+ padding="max_length",
631
+ max_length=self.tokenizer.model_max_length,
632
+ truncation=True,
633
+ return_tensors="pt",
634
+ )
635
+ text_input_ids = text_inputs.input_ids
636
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
637
+
638
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
639
+ text_input_ids, untruncated_ids
640
+ ):
641
+ removed_text = self.tokenizer.batch_decode(
642
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
643
+ )
644
+ logger.warning(
645
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
646
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
647
+ )
648
+
649
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
650
+ attention_mask = text_inputs.attention_mask.to(device)
651
+ else:
652
+ attention_mask = None
653
+
654
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
655
+ prompt_embeds = prompt_embeds[0]
656
+
657
+ if self.text_encoder is not None:
658
+ prompt_embeds_dtype = self.text_encoder.dtype
659
+ else:
660
+ prompt_embeds_dtype = self.unet.dtype
661
+
662
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
663
+
664
+ bs_embed, seq_len, _ = prompt_embeds.shape
665
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
666
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
667
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
668
+
669
+ # get unconditional embeddings for classifier free guidance
670
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
671
+ uncond_tokens: List[str]
672
+ if negative_prompt is None:
673
+ uncond_tokens = [""] * batch_size
674
+ elif type(prompt) is not type(negative_prompt):
675
+ raise TypeError(
676
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
677
+ f" {type(prompt)}."
678
+ )
679
+ elif isinstance(negative_prompt, str):
680
+ uncond_tokens = [negative_prompt]
681
+ elif batch_size != len(negative_prompt):
682
+ raise ValueError(
683
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
684
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
685
+ " the batch size of `prompt`."
686
+ )
687
+ else:
688
+ uncond_tokens = negative_prompt
689
+
690
+ # textual inversion: process multi-vector tokens if necessary
691
+ if isinstance(self, TextualInversionLoaderMixin):
692
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
693
+
694
+ max_length = prompt_embeds.shape[1]
695
+ uncond_input = self.tokenizer(
696
+ uncond_tokens,
697
+ padding="max_length",
698
+ max_length=max_length,
699
+ truncation=True,
700
+ return_tensors="pt",
701
+ )
702
+
703
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
704
+ attention_mask = uncond_input.attention_mask.to(device)
705
+ else:
706
+ attention_mask = None
707
+
708
+ negative_prompt_embeds = self.text_encoder(
709
+ uncond_input.input_ids.to(device),
710
+ attention_mask=attention_mask,
711
+ )
712
+ negative_prompt_embeds = negative_prompt_embeds[0]
713
+
714
+ if do_classifier_free_guidance:
715
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
716
+ seq_len = negative_prompt_embeds.shape[1]
717
+
718
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
719
+
720
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
721
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
722
+
723
+ # For classifier free guidance, we need to do two forward passes.
724
+ # Here we concatenate the unconditional and text embeddings into a single batch
725
+ # to avoid doing two forward passes
726
+ # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
727
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
728
+
729
+ return prompt_embeds
730
+
731
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
732
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
733
+ dtype = next(self.image_encoder.parameters()).dtype
734
+
735
+ if not isinstance(image, torch.Tensor):
736
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
737
+
738
+ image = image.to(device=device, dtype=dtype)
739
+ if output_hidden_states:
740
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
741
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
742
+ uncond_image_enc_hidden_states = self.image_encoder(
743
+ torch.zeros_like(image), output_hidden_states=True
744
+ ).hidden_states[-2]
745
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
746
+ num_images_per_prompt, dim=0
747
+ )
748
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
749
+ else:
750
+ image_embeds = self.image_encoder(image).image_embeds
751
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
752
+ uncond_image_embeds = torch.zeros_like(image_embeds)
753
+
754
+ return image_embeds, uncond_image_embeds
755
+
756
+ def prepare_ip_adapter_image_embeds(
757
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
758
+ ):
759
+ if ip_adapter_image_embeds is None:
760
+ if not isinstance(ip_adapter_image, list):
761
+ ip_adapter_image = [ip_adapter_image]
762
+
763
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
764
+ raise ValueError(
765
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
766
+ )
767
+
768
+ image_embeds = []
769
+ for single_ip_adapter_image, image_proj_layer in zip(
770
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
771
+ ):
772
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
773
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
774
+ single_ip_adapter_image, device, 1, output_hidden_state
775
+ )
776
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
777
+ single_negative_image_embeds = torch.stack(
778
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
779
+ )
780
+
781
+ if do_classifier_free_guidance:
782
+ single_image_embeds = torch.cat(
783
+ [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
784
+ )
785
+ single_image_embeds = single_image_embeds.to(device)
786
+
787
+ image_embeds.append(single_image_embeds)
788
+ else:
789
+ repeat_dims = [1]
790
+ image_embeds = []
791
+ for single_image_embeds in ip_adapter_image_embeds:
792
+ if do_classifier_free_guidance:
793
+ (
794
+ single_image_embeds,
795
+ single_negative_image_embeds,
796
+ single_negative_image_embeds,
797
+ ) = single_image_embeds.chunk(3)
798
+ single_image_embeds = single_image_embeds.repeat(
799
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
800
+ )
801
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
802
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
803
+ )
804
+ single_image_embeds = torch.cat(
805
+ [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
806
+ )
807
+ else:
808
+ single_image_embeds = single_image_embeds.repeat(
809
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
810
+ )
811
+ image_embeds.append(single_image_embeds)
812
+
813
+ return image_embeds
814
+
815
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
816
+ def run_safety_checker(self, image, device, dtype):
817
+ if self.safety_checker is None:
818
+ has_nsfw_concept = None
819
+ else:
820
+ if torch.is_tensor(image):
821
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
822
+ else:
823
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
824
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
825
+ image, has_nsfw_concept = self.safety_checker(
826
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
827
+ )
828
+ return image, has_nsfw_concept
829
+
830
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
831
+ def prepare_extra_step_kwargs(self, generator, eta):
832
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
833
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
834
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
835
+ # and should be between [0, 1]
836
+
837
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
838
+ extra_step_kwargs = {}
839
+ if accepts_eta:
840
+ extra_step_kwargs["eta"] = eta
841
+
842
+ # check if the scheduler accepts generator
843
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
844
+ if accepts_generator:
845
+ extra_step_kwargs["generator"] = generator
846
+ return extra_step_kwargs
847
+
848
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
849
+ def decode_latents(self, latents):
850
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
851
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
852
+
853
+ latents = 1 / self.vae.config.scaling_factor * latents
854
+ image = self.vae.decode(latents, return_dict=False)[0]
855
+ image = (image / 2 + 0.5).clamp(0, 1)
856
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
857
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
858
+ return image
859
+
860
+ def check_inputs(
861
+ self,
862
+ prompt,
863
+ callback_steps,
864
+ negative_prompt=None,
865
+ prompt_embeds=None,
866
+ negative_prompt_embeds=None,
867
+ ip_adapter_image=None,
868
+ ip_adapter_image_embeds=None,
869
+ callback_on_step_end_tensor_inputs=None,
870
+ ):
871
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
872
+ raise ValueError(
873
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
874
+ f" {type(callback_steps)}."
875
+ )
876
+
877
+ if callback_on_step_end_tensor_inputs is not None and not all(
878
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
879
+ ):
880
+ raise ValueError(
881
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
882
+ )
883
+
884
+ if prompt is not None and prompt_embeds is not None:
885
+ raise ValueError(
886
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
887
+ " only forward one of the two."
888
+ )
889
+ elif prompt is None and prompt_embeds is None:
890
+ raise ValueError(
891
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
892
+ )
893
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
894
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
895
+
896
+ if negative_prompt is not None and negative_prompt_embeds is not None:
897
+ raise ValueError(
898
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
899
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
900
+ )
901
+
902
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
903
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
904
+ raise ValueError(
905
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
906
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
907
+ f" {negative_prompt_embeds.shape}."
908
+ )
909
+
910
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
911
+ raise ValueError(
912
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
913
+ )
914
+
915
+ if ip_adapter_image_embeds is not None:
916
+ if not isinstance(ip_adapter_image_embeds, list):
917
+ raise ValueError(
918
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
919
+ )
920
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
921
+ raise ValueError(
922
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
923
+ )
924
+
925
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
926
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
927
+ shape = (
928
+ batch_size,
929
+ num_channels_latents,
930
+ int(height) // self.vae_scale_factor,
931
+ int(width) // self.vae_scale_factor,
932
+ )
933
+ if isinstance(generator, list) and len(generator) != batch_size:
934
+ raise ValueError(
935
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
936
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
937
+ )
938
+
939
+ if latents is None:
940
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
941
+ else:
942
+ latents = latents.to(device)
943
+
944
+ # scale the initial noise by the standard deviation required by the scheduler
945
+ latents = latents * self.scheduler.init_noise_sigma
946
+ return latents
947
+
948
+ def _default_height_width(self, height, width, image):
949
+ # NOTE: It is possible that a list of images have different
950
+ # dimensions for each image, so just checking the first image
951
+ # is not _exactly_ correct, but it is simple.
952
+ while isinstance(image, list):
953
+ image = image[0]
954
+
955
+ if height is None:
956
+ if isinstance(image, PIL.Image.Image):
957
+ height = image.height
958
+ elif isinstance(image, torch.Tensor):
959
+ height = image.shape[-2]
960
+
961
+ # round down to nearest multiple of `self.adapter.downscale_factor`
962
+ height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
963
+
964
+ if width is None:
965
+ if isinstance(image, PIL.Image.Image):
966
+ width = image.width
967
+ elif isinstance(image, torch.Tensor):
968
+ width = image.shape[-1]
969
+
970
+ # round down to nearest multiple of `self.adapter.downscale_factor`
971
+ width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
972
+
973
+ return height, width
974
+
975
+
976
+ def prepare_image_latents(
977
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
978
+ ):
979
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
980
+ raise ValueError(
981
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
982
+ )
983
+
984
+ image = image.to(device=device, dtype=dtype)
985
+
986
+ batch_size = batch_size * num_images_per_prompt
987
+
988
+ if image.shape[1] == 4:
989
+ image_latents = image
990
+ else:
991
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
992
+
993
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
994
+ # expand image_latents for batch_size
995
+ deprecation_message = (
996
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
997
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
998
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
999
+ " your script to pass as many initial images as text prompts to suppress this warning."
1000
+ )
1001
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
1002
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
1003
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
1004
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
1005
+ raise ValueError(
1006
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
1007
+ )
1008
+ else:
1009
+ image_latents = torch.cat([image_latents], dim=0)
1010
+
1011
+ if do_classifier_free_guidance:
1012
+ uncond_image_latents = torch.zeros_like(image_latents)
1013
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
1014
+
1015
+ return image_latents
1016
+
1017
+ @property
1018
+ def guidance_scale(self):
1019
+ return self._guidance_scale
1020
+
1021
+ @property
1022
+ def image_guidance_scale(self):
1023
+ return self._image_guidance_scale
1024
+
1025
+ @property
1026
+ def num_timesteps(self):
1027
+ return self._num_timesteps
1028
+
1029
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1030
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1031
+ # corresponds to doing no classifier free guidance.
1032
+ @property
1033
+ def do_classifier_free_guidance(self):
1034
+ return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
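For reference, the denoising loop above splits each batch into three noise predictions (text+image, image-only, unconditional) and recombines them with the two guidance weights. A standalone toy sketch of that guidance arithmetic, with random placeholder tensors, looks like this:

```python
# Toy illustration of the three-way guidance combine used in the denoising loop above.
# The tensors are random placeholders; only the arithmetic mirrors the pipeline.
import torch

guidance_scale = 7.5        # text guidance weight
image_guidance_scale = 1.5  # image guidance weight

noise_pred = torch.randn(3, 4, 64, 64)  # stacked as [text, image, uncond]
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)

noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```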