vinesmsuic committed
Commit bb2108c
1 Parent(s): b15506b
app.py CHANGED
@@ -14,7 +14,8 @@ from PIL import Image
  import torch
  import numpy as np
 
- from black_box_image_edit.instructpix2pix import InstructPix2Pix
+ from black_box_image_edit.cosxl_edit import CosXLEdit
+ #from black_box_image_edit.instructpix2pix import InstructPix2Pix
  from prepare_video import crop_and_resize_video
  from edit_image import infer_video
 
@@ -40,7 +41,7 @@ demo_examples = [
  TEMP_DIR = "_demo_temp"
 
 
- image_edit_model = InstructPix2Pix()
+ image_edit_model = CosXLEdit()
 
  @torch.no_grad()
  @spaces.GPU(duration=30)
@@ -315,7 +316,7 @@ with gr.Blocks() as demo:
  gr.Markdown("Official 🤗 Gradio demo for [AnyV2V: A Plug-and-Play Framework For Any Video-to-Video Editing Tasks](https://tiger-ai-lab.github.io/AnyV2V/)")
 
  with gr.Tabs():
- with gr.TabItem('AnyV2V(I2VGenXL) + InstructPix2Pix'):
+ with gr.TabItem('AnyV2V(I2VGenXL) + CosXLEdit'):
  gr.Markdown("# Preprocessing Video Stage")
  gr.Markdown("In this demo, AnyV2V only support video with 2 seconds duration and 8 fps. If your video is not in this format, we will preprocess it for you. Click on the Preprocess video button!")
  with gr.Row():
@@ -339,7 +340,7 @@
  pv_longest_to_width = gr.Checkbox(label="Resize Longest Dimension to Width")
 
  gr.Markdown("# Image Editing Stage")
- gr.Markdown("Edit the first frame of the video to your liking! Click on the Edit the first frame button after inputting the editing instruction prompt. This image editing stage is powered by InstructPix2Pix. You can try edit the image multiple times until you are happy with the result! You can also choose to download the first frame of the video and edit it with other software (e.g. Photoshop, GIMP, etc.) or use other image editing models to obtain the edited frame and upload it directly.")
+ gr.Markdown("Edit the first frame of the video to your liking! Click on the Edit the first frame button after inputting the editing instruction prompt. This image editing stage is powered by CosXLEdit. You can try edit the image multiple times until you are happy with the result! You can also choose to download the first frame of the video and edit it with other software (e.g. Photoshop, GIMP, etc.) or use other image editing models to obtain the edited frame and upload it directly.")
  with gr.Row():
  with gr.Column():
  src_first_frame = gr.Image(label="First Frame", type="filepath", interactive=False)
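
The only functional change in `app.py` is the swap of the black-box image editor: the `InstructPix2Pix` wrapper is replaced by `CosXLEdit`, which the app constructs once at module level and drives through the same `infer_one_image`-style interface (see `black_box_image_edit/cosxl_edit.py` below). A minimal usage sketch, assuming a first frame has already been extracted to a hypothetical `frame_0000.png`:

```python
from PIL import Image
from black_box_image_edit.cosxl_edit import CosXLEdit

# Built once at import time in app.py; downloads cosxl_edit.safetensors on first use.
image_edit_model = CosXLEdit()

first_frame = Image.open("frame_0000.png")  # hypothetical path to the extracted first frame
edited = image_edit_model.infer_one_image(
    src_image=first_frame,
    instruct_prompt="turn the sky into a cloudy one",  # example editing instruction
    seed=42,
)
edited.save("frame_0000_edited.png")
```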
black_box_image_edit/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from .instructpix2pix import InstructPix2Pix, MagicBrush
+ from .cosxl_edit import CosXLEdit
 
  from typing import Union, Optional, Tuple
  import numpy as np
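
With this export added, callers can pull the new editor from the package root alongside the existing wrappers, e.g.:

```python
from black_box_image_edit import CosXLEdit, InstructPix2Pix, MagicBrush
```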
black_box_image_edit/cosxl/custom_pipeline.py ADDED
@@ -0,0 +1,977 @@
1
+ # Copyright 2024 Harutatsu Akiyama and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import PIL.Image
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
21
+
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
24
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
+ from diffusers.models.attention_processor import (
26
+ AttnProcessor2_0,
27
+ FusedAttnProcessor2_0,
28
+ LoRAAttnProcessor2_0,
29
+ LoRAXFormersAttnProcessor,
30
+ XFormersAttnProcessor,
31
+ )
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ is_invisible_watermark_available,
38
+ is_torch_xla_available,
39
+ logging,
40
+ replace_example_docstring,
41
+ scale_lora_layers,
42
+ )
43
+ from diffusers.utils.torch_utils import randn_tensor
44
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
45
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
46
+
47
+
48
+ if is_invisible_watermark_available():
49
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+
59
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline
66
+ >>> from diffusers.utils import load_image
67
+ >>> resolution = 768
68
+ >>> image = load_image(
69
+ ... "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
70
+ ... ).resize((resolution, resolution))
71
+ >>> edit_instruction = "Turn sky into a cloudy one"
72
+ >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
73
+ ... "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16
74
+ ... ).to("cuda")
75
+ >>> edited_image = pipe(
76
+ ... prompt=edit_instruction,
77
+ ... image=image,
78
+ ... height=resolution,
79
+ ... width=resolution,
80
+ ... guidance_scale=3.0,
81
+ ... image_guidance_scale=1.5,
82
+ ... num_inference_steps=30,
83
+ ... ).images[0]
84
+ >>> edited_image
85
+ ```
86
+ """
87
+
88
+
89
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
90
+ def retrieve_latents(
91
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
92
+ ):
93
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
94
+ return encoder_output.latent_dist.sample(generator)
95
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
96
+ return encoder_output.latent_dist.mode()
97
+ elif hasattr(encoder_output, "latents"):
98
+ return encoder_output.latents
99
+ else:
100
+ raise AttributeError("Could not access latents of provided encoder_output")
101
+
102
+
103
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
104
+ """
105
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
106
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
107
+ """
108
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
109
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
110
+ # rescale the results from guidance (fixes overexposure)
111
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
112
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
113
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
114
+ return noise_cfg
115
+
116
+
117
+ class CosStableDiffusionXLInstructPix2PixPipeline(
118
+ DiffusionPipeline,
119
+ StableDiffusionMixin,
120
+ TextualInversionLoaderMixin,
121
+ FromSingleFileMixin,
122
+ StableDiffusionXLLoraLoaderMixin,
123
+ ):
124
+ r"""
125
+ Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.
126
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
127
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
128
+ The pipeline also inherits the following loading methods:
129
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
130
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
131
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
132
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
133
+ Args:
134
+ vae ([`AutoencoderKL`]):
135
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
136
+ text_encoder ([`CLIPTextModel`]):
137
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
138
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
139
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
140
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
141
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
142
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
143
+ specifically the
144
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
145
+ variant.
146
+ tokenizer (`CLIPTokenizer`):
147
+ Tokenizer of class
148
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
149
+ tokenizer_2 (`CLIPTokenizer`):
150
+ Second Tokenizer of class
151
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
152
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
153
+ scheduler ([`SchedulerMixin`]):
154
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
155
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
156
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
157
+ Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
158
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
159
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
160
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
161
+ `stabilityai/stable-diffusion-xl-base-1-0`.
162
+ add_watermarker (`bool`, *optional*):
163
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
164
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
165
+ watermarker will be used.
166
+ """
167
+
168
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
169
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
170
+
171
+ def __init__(
172
+ self,
173
+ vae: AutoencoderKL,
174
+ text_encoder: CLIPTextModel,
175
+ text_encoder_2: CLIPTextModelWithProjection,
176
+ tokenizer: CLIPTokenizer,
177
+ tokenizer_2: CLIPTokenizer,
178
+ unet: UNet2DConditionModel,
179
+ scheduler: KarrasDiffusionSchedulers,
180
+ force_zeros_for_empty_prompt: bool = True,
181
+ add_watermarker: Optional[bool] = None,
182
+ ):
183
+ super().__init__()
184
+
185
+ self.register_modules(
186
+ vae=vae,
187
+ text_encoder=text_encoder,
188
+ text_encoder_2=text_encoder_2,
189
+ tokenizer=tokenizer,
190
+ tokenizer_2=tokenizer_2,
191
+ unet=unet,
192
+ scheduler=scheduler,
193
+ )
194
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
195
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
196
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
197
+ self.default_sample_size = self.unet.config.sample_size
198
+
199
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
200
+
201
+ if add_watermarker:
202
+ self.watermark = StableDiffusionXLWatermarker()
203
+ else:
204
+ self.watermark = None
205
+
206
+ def encode_prompt(
207
+ self,
208
+ prompt: str,
209
+ prompt_2: Optional[str] = None,
210
+ device: Optional[torch.device] = None,
211
+ num_images_per_prompt: int = 1,
212
+ do_classifier_free_guidance: bool = True,
213
+ negative_prompt: Optional[str] = None,
214
+ negative_prompt_2: Optional[str] = None,
215
+ prompt_embeds: Optional[torch.FloatTensor] = None,
216
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
217
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
218
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
219
+ lora_scale: Optional[float] = None,
220
+ ):
221
+ r"""
222
+ Encodes the prompt into text encoder hidden states.
223
+ Args:
224
+ prompt (`str` or `List[str]`, *optional*):
225
+ prompt to be encoded
226
+ prompt_2 (`str` or `List[str]`, *optional*):
227
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
228
+ used in both text-encoders
229
+ device: (`torch.device`):
230
+ torch device
231
+ num_images_per_prompt (`int`):
232
+ number of images that should be generated per prompt
233
+ do_classifier_free_guidance (`bool`):
234
+ whether to use classifier free guidance or not
235
+ negative_prompt (`str` or `List[str]`, *optional*):
236
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
237
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
238
+ less than `1`).
239
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
240
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
241
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
242
+ prompt_embeds (`torch.FloatTensor`, *optional*):
243
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
244
+ provided, text embeddings will be generated from `prompt` input argument.
245
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
246
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
247
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
248
+ argument.
249
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
250
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
251
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
252
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
253
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
254
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
255
+ input argument.
256
+ lora_scale (`float`, *optional*):
257
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
258
+ """
259
+ device = device or self._execution_device
260
+
261
+ # set lora scale so that monkey patched LoRA
262
+ # function of text encoder can correctly access it
263
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
264
+ self._lora_scale = lora_scale
265
+
266
+ # dynamically adjust the LoRA scale
267
+ if self.text_encoder is not None:
268
+ if not USE_PEFT_BACKEND:
269
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
270
+ else:
271
+ scale_lora_layers(self.text_encoder, lora_scale)
272
+
273
+ if self.text_encoder_2 is not None:
274
+ if not USE_PEFT_BACKEND:
275
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
276
+ else:
277
+ scale_lora_layers(self.text_encoder_2, lora_scale)
278
+
279
+ if prompt is not None and isinstance(prompt, str):
280
+ batch_size = 1
281
+ elif prompt is not None and isinstance(prompt, list):
282
+ batch_size = len(prompt)
283
+ else:
284
+ batch_size = prompt_embeds.shape[0]
285
+
286
+ # Define tokenizers and text encoders
287
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
288
+ text_encoders = (
289
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
290
+ )
291
+
292
+ if prompt_embeds is None:
293
+ prompt_2 = prompt_2 or prompt
294
+ # textual inversion: process multi-vector tokens if necessary
295
+ prompt_embeds_list = []
296
+ prompts = [prompt, prompt_2]
297
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
298
+ if isinstance(self, TextualInversionLoaderMixin):
299
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
300
+
301
+ text_inputs = tokenizer(
302
+ prompt,
303
+ padding="max_length",
304
+ max_length=tokenizer.model_max_length,
305
+ truncation=True,
306
+ return_tensors="pt",
307
+ )
308
+
309
+ text_input_ids = text_inputs.input_ids
310
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
311
+
312
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
313
+ text_input_ids, untruncated_ids
314
+ ):
315
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
316
+ logger.warning(
317
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
318
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
319
+ )
320
+
321
+ prompt_embeds = text_encoder(
322
+ text_input_ids.to(device),
323
+ output_hidden_states=True,
324
+ )
325
+
326
+ # We are only ALWAYS interested in the pooled output of the final text encoder
327
+ pooled_prompt_embeds = prompt_embeds[0]
328
+ prompt_embeds = prompt_embeds.hidden_states[-2]
329
+
330
+ prompt_embeds_list.append(prompt_embeds)
331
+
332
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
333
+
334
+ # get unconditional embeddings for classifier free guidance
335
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
336
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
337
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
338
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
339
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
340
+ negative_prompt = negative_prompt or ""
341
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
342
+
343
+ uncond_tokens: List[str]
344
+ if prompt is not None and type(prompt) is not type(negative_prompt):
345
+ raise TypeError(
346
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
347
+ f" {type(prompt)}."
348
+ )
349
+ elif isinstance(negative_prompt, str):
350
+ uncond_tokens = [negative_prompt, negative_prompt_2]
351
+ elif batch_size != len(negative_prompt):
352
+ raise ValueError(
353
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
354
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
355
+ " the batch size of `prompt`."
356
+ )
357
+ else:
358
+ uncond_tokens = [negative_prompt, negative_prompt_2]
359
+
360
+ negative_prompt_embeds_list = []
361
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
362
+ if isinstance(self, TextualInversionLoaderMixin):
363
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
364
+
365
+ max_length = prompt_embeds.shape[1]
366
+ uncond_input = tokenizer(
367
+ negative_prompt,
368
+ padding="max_length",
369
+ max_length=max_length,
370
+ truncation=True,
371
+ return_tensors="pt",
372
+ )
373
+
374
+ negative_prompt_embeds = text_encoder(
375
+ uncond_input.input_ids.to(device),
376
+ output_hidden_states=True,
377
+ )
378
+ # We are only ALWAYS interested in the pooled output of the final text encoder
379
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
380
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
381
+
382
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
383
+
384
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
385
+
386
+ prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype
387
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
388
+ bs_embed, seq_len, _ = prompt_embeds.shape
389
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
390
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
391
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
392
+
393
+ if do_classifier_free_guidance:
394
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
395
+ seq_len = negative_prompt_embeds.shape[1]
396
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
397
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
398
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
399
+
400
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
401
+ bs_embed * num_images_per_prompt, -1
402
+ )
403
+ if do_classifier_free_guidance:
404
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
405
+ bs_embed * num_images_per_prompt, -1
406
+ )
407
+
408
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
409
+
410
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
411
+ def prepare_extra_step_kwargs(self, generator, eta):
412
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
413
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
414
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
415
+ # and should be between [0, 1]
416
+
417
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
418
+ extra_step_kwargs = {}
419
+ if accepts_eta:
420
+ extra_step_kwargs["eta"] = eta
421
+
422
+ # check if the scheduler accepts generator
423
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
424
+ if accepts_generator:
425
+ extra_step_kwargs["generator"] = generator
426
+ return extra_step_kwargs
427
+
428
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
429
+ def check_inputs(
430
+ self,
431
+ prompt,
432
+ callback_steps,
433
+ negative_prompt=None,
434
+ prompt_embeds=None,
435
+ negative_prompt_embeds=None,
436
+ callback_on_step_end_tensor_inputs=None,
437
+ ):
438
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
439
+ raise ValueError(
440
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
441
+ f" {type(callback_steps)}."
442
+ )
443
+
444
+ if callback_on_step_end_tensor_inputs is not None and not all(
445
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
446
+ ):
447
+ raise ValueError(
448
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
449
+ )
450
+
451
+ if prompt is not None and prompt_embeds is not None:
452
+ raise ValueError(
453
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
454
+ " only forward one of the two."
455
+ )
456
+ elif prompt is None and prompt_embeds is None:
457
+ raise ValueError(
458
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
459
+ )
460
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
461
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
462
+
463
+ if negative_prompt is not None and negative_prompt_embeds is not None:
464
+ raise ValueError(
465
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
466
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
467
+ )
468
+
469
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
470
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
471
+ raise ValueError(
472
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
473
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
474
+ f" {negative_prompt_embeds.shape}."
475
+ )
476
+
477
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
478
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
479
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
480
+ if isinstance(generator, list) and len(generator) != batch_size:
481
+ raise ValueError(
482
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
483
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
484
+ )
485
+
486
+ if latents is None:
487
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
488
+ else:
489
+ latents = latents.to(device)
490
+
491
+ # scale the initial noise by the standard deviation required by the scheduler
492
+ latents = latents * self.scheduler.init_noise_sigma
493
+ return latents
494
+
495
+ def prepare_image_latents(
496
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
497
+ ):
498
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
499
+ raise ValueError(
500
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
501
+ )
502
+
503
+ image = image.to(device=device, dtype=dtype)
504
+
505
+ batch_size = batch_size * num_images_per_prompt
506
+
507
+ if image.shape[1] == 4:
508
+ image_latents = image
509
+ else:
510
+ # make sure the VAE is in float32 mode, as it overflows in float16
511
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
512
+ if needs_upcasting:
513
+ self.upcast_vae()
514
+ image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
515
+
516
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
517
+
518
+ # cast back to fp16 if needed
519
+ if needs_upcasting:
520
+ self.vae.to(dtype=torch.float16)
521
+
522
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
523
+ # expand image_latents for batch_size
524
+ deprecation_message = (
525
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
526
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
527
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
528
+ " your script to pass as many initial images as text prompts to suppress this warning."
529
+ )
530
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
531
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
532
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
533
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
534
+ raise ValueError(
535
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
536
+ )
537
+ else:
538
+ image_latents = torch.cat([image_latents], dim=0)
539
+
540
+ if do_classifier_free_guidance:
541
+ uncond_image_latents = torch.zeros_like(image_latents)
542
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
543
+
544
+ if image_latents.dtype != self.vae.dtype:
545
+ image_latents = image_latents.to(dtype=self.vae.dtype)
546
+
547
+ return image_latents
548
+
549
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
550
+ def _get_add_time_ids(
551
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
552
+ ):
553
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
554
+
555
+ passed_add_embed_dim = (
556
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
557
+ )
558
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
559
+
560
+ if expected_add_embed_dim != passed_add_embed_dim:
561
+ raise ValueError(
562
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
563
+ )
564
+
565
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
566
+ return add_time_ids
567
+
568
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
569
+ def upcast_vae(self):
570
+ dtype = self.vae.dtype
571
+ self.vae.to(dtype=torch.float32)
572
+ use_torch_2_0_or_xformers = isinstance(
573
+ self.vae.decoder.mid_block.attentions[0].processor,
574
+ (
575
+ AttnProcessor2_0,
576
+ XFormersAttnProcessor,
577
+ LoRAXFormersAttnProcessor,
578
+ LoRAAttnProcessor2_0,
579
+ FusedAttnProcessor2_0,
580
+ ),
581
+ )
582
+ # if xformers or torch_2_0 is used attention block does not need
583
+ # to be in float32 which can save lots of memory
584
+ if use_torch_2_0_or_xformers:
585
+ self.vae.post_quant_conv.to(dtype)
586
+ self.vae.decoder.conv_in.to(dtype)
587
+ self.vae.decoder.mid_block.to(dtype)
588
+
589
+ @torch.no_grad()
590
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
591
+ def __call__(
592
+ self,
593
+ prompt: Union[str, List[str]] = None,
594
+ prompt_2: Optional[Union[str, List[str]]] = None,
595
+ image: PipelineImageInput = None,
596
+ height: Optional[int] = None,
597
+ width: Optional[int] = None,
598
+ num_inference_steps: int = 100,
599
+ denoising_end: Optional[float] = None,
600
+ guidance_scale: float = 5.0,
601
+ image_guidance_scale: float = 1.5,
602
+ negative_prompt: Optional[Union[str, List[str]]] = None,
603
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
604
+ num_images_per_prompt: Optional[int] = 1,
605
+ eta: float = 0.0,
606
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
607
+ latents: Optional[torch.FloatTensor] = None,
608
+ prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
610
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
611
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
612
+ output_type: Optional[str] = "pil",
613
+ return_dict: bool = True,
614
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
615
+ callback_steps: int = 1,
616
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
617
+ guidance_rescale: float = 0.0,
618
+ original_size: Tuple[int, int] = None,
619
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
620
+ target_size: Tuple[int, int] = None,
621
+ ):
622
+ r"""
623
+ Function invoked when calling the pipeline for generation.
624
+ Args:
625
+ prompt (`str` or `List[str]`, *optional*):
626
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
627
+ instead.
628
+ prompt_2 (`str` or `List[str]`, *optional*):
629
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
630
+ used in both text-encoders
631
+ image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
632
+ The image(s) to modify with the pipeline.
633
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
634
+ The height in pixels of the generated image.
635
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
636
+ The width in pixels of the generated image.
637
+ num_inference_steps (`int`, *optional*, defaults to 50):
638
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
639
+ expense of slower inference.
640
+ denoising_end (`float`, *optional*):
641
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
642
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
643
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
644
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
645
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
646
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
647
+ guidance_scale (`float`, *optional*, defaults to 5.0):
648
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
649
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
650
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
651
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
652
+ usually at the expense of lower image quality.
653
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
654
+ Image guidance scale is to push the generated image towards the initial image `image`. Image guidance
655
+ scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to
656
+ generate images that are closely linked to the source image `image`, usually at the expense of lower
657
+ image quality. This pipeline requires a value of at least `1`.
658
+ negative_prompt (`str` or `List[str]`, *optional*):
659
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
660
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
661
+ less than `1`).
662
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
663
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
664
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
665
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
666
+ The number of images to generate per prompt.
667
+ eta (`float`, *optional*, defaults to 0.0):
668
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
669
+ [`schedulers.DDIMScheduler`], will be ignored for others.
670
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
671
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
672
+ to make generation deterministic.
673
+ latents (`torch.FloatTensor`, *optional*):
674
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
675
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
676
+ tensor will ge generated by sampling using the supplied random `generator`.
677
+ prompt_embeds (`torch.FloatTensor`, *optional*):
678
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
679
+ provided, text embeddings will be generated from `prompt` input argument.
680
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
681
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
682
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
683
+ argument.
684
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
685
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
686
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
687
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
688
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
689
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
690
+ input argument.
691
+ output_type (`str`, *optional*, defaults to `"pil"`):
692
+ The output format of the generate image. Choose between
693
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
694
+ return_dict (`bool`, *optional*, defaults to `True`):
695
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
696
+ plain tuple.
697
+ callback (`Callable`, *optional*):
698
+ A function that will be called every `callback_steps` steps during inference. The function will be
699
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
700
+ callback_steps (`int`, *optional*, defaults to 1):
701
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
702
+ called at every step.
703
+ cross_attention_kwargs (`dict`, *optional*):
704
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
705
+ `self.processor` in
706
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
707
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
708
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
709
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
710
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
711
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
712
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
713
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
714
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
715
+ explained in section 2.2 of
716
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
717
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
718
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
719
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
720
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
721
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
722
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
723
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
724
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
725
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
726
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
727
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
728
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
729
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
730
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
731
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
732
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
733
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
734
+ Examples:
735
+ Returns:
736
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
737
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
738
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
739
+ """
740
+ # 0. Default height and width to unet
741
+ height = height or self.default_sample_size * self.vae_scale_factor
742
+ width = width or self.default_sample_size * self.vae_scale_factor
743
+
744
+ original_size = original_size or (height, width)
745
+ target_size = target_size or (height, width)
746
+
747
+ # 1. Check inputs. Raise error if not correct
748
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
749
+
750
+ if image is None:
751
+ raise ValueError("`image` input cannot be undefined.")
752
+
753
+ # 2. Define call parameters
754
+ if prompt is not None and isinstance(prompt, str):
755
+ batch_size = 1
756
+ elif prompt is not None and isinstance(prompt, list):
757
+ batch_size = len(prompt)
758
+ else:
759
+ batch_size = prompt_embeds.shape[0]
760
+
761
+ device = self._execution_device
762
+
763
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
764
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
765
+ # corresponds to doing no classifier free guidance.
766
+ do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
767
+
768
+ # 3. Encode input prompt
769
+ text_encoder_lora_scale = (
770
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
771
+ )
772
+ (
773
+ prompt_embeds,
774
+ negative_prompt_embeds,
775
+ pooled_prompt_embeds,
776
+ negative_pooled_prompt_embeds,
777
+ ) = self.encode_prompt(
778
+ prompt=prompt,
779
+ prompt_2=prompt_2,
780
+ device=device,
781
+ num_images_per_prompt=num_images_per_prompt,
782
+ do_classifier_free_guidance=do_classifier_free_guidance,
783
+ negative_prompt=negative_prompt,
784
+ negative_prompt_2=negative_prompt_2,
785
+ prompt_embeds=prompt_embeds,
786
+ negative_prompt_embeds=negative_prompt_embeds,
787
+ pooled_prompt_embeds=pooled_prompt_embeds,
788
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
789
+ lora_scale=text_encoder_lora_scale,
790
+ )
791
+
792
+ # 4. Preprocess image
793
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
794
+
795
+ # 5. Prepare timesteps
796
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
797
+ timesteps = self.scheduler.timesteps
798
+
799
+ # 6. Prepare Image latents
800
+ image_latents = self.prepare_image_latents(
801
+ image,
802
+ batch_size,
803
+ num_images_per_prompt,
804
+ prompt_embeds.dtype,
805
+ device,
806
+ do_classifier_free_guidance,
807
+ )
808
+
809
+ image_latents = image_latents * self.vae.config.scaling_factor
810
+
811
+ # 7. Prepare latent variables
812
+ num_channels_latents = self.vae.config.latent_channels
813
+ latents = self.prepare_latents(
814
+ batch_size * num_images_per_prompt,
815
+ num_channels_latents,
816
+ height,
817
+ width,
818
+ prompt_embeds.dtype,
819
+ device,
820
+ generator,
821
+ latents,
822
+ )
823
+
824
+ # 8. Check that shapes of latents and image match the UNet channels
825
+ num_channels_image = image_latents.shape[1]
826
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
827
+ raise ValueError(
828
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
829
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
830
+ f" `num_channels_image`: {num_channels_image} "
831
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
832
+ " `pipeline.unet` or your `image` input."
833
+ )
834
+
835
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
836
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
837
+
838
+ # 10. Prepare added time ids & embeddings
839
+ add_text_embeds = pooled_prompt_embeds
840
+ if self.text_encoder_2 is None:
841
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
842
+ else:
843
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
844
+
845
+ add_time_ids = self._get_add_time_ids(
846
+ original_size,
847
+ crops_coords_top_left,
848
+ target_size,
849
+ dtype=prompt_embeds.dtype,
850
+ text_encoder_projection_dim=text_encoder_projection_dim,
851
+ )
852
+
853
+ if do_classifier_free_guidance:
854
+ # The extra concat similar to how it's done in SD InstructPix2Pix.
855
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
856
+ add_text_embeds = torch.cat(
857
+ [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0
858
+ )
859
+ add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0)
860
+
861
+ prompt_embeds = prompt_embeds.to(device)
862
+ add_text_embeds = add_text_embeds.to(device)
863
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
864
+
865
+ # 11. Denoising loop
866
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
867
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
868
+ discrete_timestep_cutoff = int(
869
+ round(
870
+ self.scheduler.config.num_train_timesteps
871
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
872
+ )
873
+ )
874
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
875
+ timesteps = timesteps[:num_inference_steps]
876
+
877
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
878
+ for i, t in enumerate(timesteps):
879
+ # Expand the latents if we are doing classifier free guidance.
880
+ # The latents are expanded 3 times because for pix2pix the guidance
881
+ # is applied for both the text and the input image.
882
+ latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
883
+
884
+ # concat latents, image_latents in the channel dimension
885
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
886
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
887
+
888
+ # predict the noise residual
889
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
890
+ noise_pred = self.unet(
891
+ scaled_latent_model_input,
892
+ t,
893
+ encoder_hidden_states=prompt_embeds,
894
+ cross_attention_kwargs=cross_attention_kwargs,
895
+ added_cond_kwargs=added_cond_kwargs,
896
+ return_dict=False,
897
+ )[0]
898
+
899
+ # perform guidance
900
+ if do_classifier_free_guidance:
901
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
902
+ noise_pred = (
903
+ noise_pred_uncond
904
+ + guidance_scale * (noise_pred_text - noise_pred_image)
905
+ + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
906
+ )
907
+
908
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
909
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
910
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
911
+
912
+ # compute the previous noisy sample x_t -> x_t-1
913
+ latents_dtype = latents.dtype
914
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
915
+ if latents.dtype != latents_dtype:
916
+ if torch.backends.mps.is_available():
917
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
918
+ latents = latents.to(latents_dtype)
919
+
920
+ # call the callback, if provided
921
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
922
+ progress_bar.update()
923
+ if callback is not None and i % callback_steps == 0:
924
+ step_idx = i // getattr(self.scheduler, "order", 1)
925
+ callback(step_idx, t, latents)
926
+
927
+ if XLA_AVAILABLE:
928
+ xm.mark_step()
929
+
930
+ if not output_type == "latent":
931
+ # make sure the VAE is in float32 mode, as it overflows in float16
932
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
933
+
934
+ if needs_upcasting:
935
+ self.upcast_vae()
936
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
937
+ elif latents.dtype != self.vae.dtype:
938
+ if torch.backends.mps.is_available():
939
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
940
+ self.vae = self.vae.to(latents.dtype)
941
+
942
+ # unscale/denormalize the latents
943
+ # denormalize with the mean and std if available and not None
944
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
945
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
946
+ if has_latents_mean and has_latents_std:
947
+ latents_mean = (
948
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
949
+ )
950
+ latents_std = (
951
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
952
+ )
953
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
954
+ else:
955
+ latents = latents / self.vae.config.scaling_factor
956
+
957
+ image = self.vae.decode(latents, return_dict=False)[0]
958
+
959
+ # cast back to fp16 if needed
960
+ if needs_upcasting:
961
+ self.vae.to(dtype=torch.float16)
962
+ else:
963
+ return StableDiffusionXLPipelineOutput(images=latents)
964
+
965
+ # apply watermark if available
966
+ if self.watermark is not None:
967
+ image = self.watermark.apply_watermark(image)
968
+
969
+ image = self.image_processor.postprocess(image, output_type=output_type)
970
+
971
+ # Offload all models
972
+ self.maybe_free_model_hooks()
973
+
974
+ if not return_dict:
975
+ return (image,)
976
+
977
+ return StableDiffusionXLPipelineOutput(images=image)
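
The distinctive part of this pipeline is its InstructPix2Pix-style three-way classifier-free guidance: the latents are expanded three times so the UNet predicts noise for (text + image) conditioning, image-only conditioning, and the unconditional case, and the three predictions are blended with `guidance_scale` and `image_guidance_scale`. A standalone sketch of that recombination step from the denoising loop (the helper name is ours, not part of the file):

```python
import torch

def combine_pix2pix_guidance(noise_pred: torch.Tensor,
                             guidance_scale: float = 5.0,
                             image_guidance_scale: float = 1.5) -> torch.Tensor:
    # noise_pred is the UNet output for the 3x-expanded batch:
    # [text+image cond, image-only cond, unconditional]
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (noise_pred_uncond
            + guidance_scale * (noise_pred_text - noise_pred_image)
            + image_guidance_scale * (noise_pred_image - noise_pred_uncond))
```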
black_box_image_edit/cosxl/utils.py ADDED
@@ -0,0 +1,17 @@
+ import math
+ import numpy as np
+ import torch
+
+ def set_timesteps_patched(self, num_inference_steps: int, device = None):
+     self.num_inference_steps = num_inference_steps
+
+     ramp = np.linspace(0, 1, self.num_inference_steps)
+     sigmas = torch.linspace(math.log(self.config.sigma_min), math.log(self.config.sigma_max), len(ramp)).exp().flip(0)
+
+     sigmas = (sigmas).to(dtype=torch.float32, device=device)
+     self.timesteps = self.precondition_noise(sigmas)
+
+     self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+     self._step_index = None
+     self._begin_index = None
+     self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
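
`set_timesteps_patched` replaces the stock `EDMEulerScheduler.set_timesteps` with a log-spaced sigma schedule running from `sigma_max` down to `sigma_min`; `cosxl_edit.py` (next file) installs it by monkey-patching the class before constructing the scheduler. A small sketch of that wiring, with the scheduler parameters copied from `cosxl_edit.py`:

```python
from diffusers import EDMEulerScheduler
from black_box_image_edit.cosxl.utils import set_timesteps_patched

# Replace the method on the class, exactly as CosXLEdit.__init__ does.
EDMEulerScheduler.set_timesteps = set_timesteps_patched

scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0,
                              sigma_data=1.0, prediction_type="v_prediction")
scheduler.set_timesteps(20)   # 20 sigmas, log-spaced, descending from 120.0 to 0.002
```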
black_box_image_edit/cosxl_edit.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ from huggingface_hub import hf_hub_download
3
+ import torch
4
+ import PIL
5
+
6
+ class CosXLEdit():
7
+ """
8
+ Edit Cos Stable Diffusion XL 1.0 Base is tuned to use a Cosine-Continuous EDM VPred schedule, and then upgraded to perform instructed image editing.
9
+ Reference: https://huggingface.co/stabilityai/cosxl
10
+ """
11
+ def __init__(self, device="cuda"):
12
+ """
13
+ Attributes:
14
+ pipe (CosStableDiffusionXLInstructPix2PixPipeline): The InstructPix2Pix pipeline for image transformation.
15
+
16
+ Args:
17
+ device (str, optional): Device on which the pipeline runs. Defaults to "cuda".
18
+ """
19
+ from diffusers import EDMEulerScheduler
20
+ from .cosxl.custom_pipeline import CosStableDiffusionXLInstructPix2PixPipeline
21
+ from .cosxl.utils import set_timesteps_patched
22
+
23
+ EDMEulerScheduler.set_timesteps = set_timesteps_patched
24
+ edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
25
+ self.pipe = CosStableDiffusionXLInstructPix2PixPipeline.from_single_file(
26
+ edit_file, num_in_channels=8
27
+ )
28
+ self.pipe.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
29
+ self.pipe.to(device)
30
+
31
+ def infer_one_image(self, src_image: PIL.Image.Image = None, src_prompt: str = None, target_prompt: str = None, instruct_prompt: str = None, seed: int = 42, negative_prompt=""):
32
+ """
33
+ Modifies the source image based on the provided instruction prompt.
34
+
35
+ Args:
36
+ src_image (PIL.Image.Image): Source image in RGB format.
37
+ instruct_prompt (str): Caption for editing the image.
38
+ seed (int, optional): Seed for random generator. Defaults to 42.
39
+
40
+ Returns:
41
+ PIL.Image.Image: The transformed image.
42
+ """
43
+ src_image = src_image.convert('RGB') # force it to RGB format
44
+ generator = torch.manual_seed(seed)
45
+
46
+ resolution = 1024
47
+ preprocessed_image = src_image.resize((resolution, resolution))
48
+ image = self.pipe(prompt=instruct_prompt,
49
+ image=preprocessed_image,
50
+ height=resolution,
51
+ width=resolution,
52
+ negative_prompt=negative_prompt,
53
+ guidance_scale=7,
54
+ num_inference_steps=20,
55
+ generator=generator).images[0]
56
+ image = image.resize((src_image.width, src_image.height))
57
+
58
+ return image
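A minimal usage sketch of the editor defined above (not part of the commit); the device is assumed to be CUDA and the image paths and instruction are hypothetical:

# --- sketch: using CosXLEdit directly (not part of the commit) ---
from PIL import Image
from black_box_image_edit.cosxl_edit import CosXLEdit

editor = CosXLEdit(device="cuda")              # downloads cosxl_edit.safetensors on first use
frame = Image.open("input.png")                # hypothetical first frame of a video
edited = editor.infer_one_image(
    src_image=frame,
    instruct_prompt="turn the car into a red sports car",  # hypothetical instruction
    seed=42,
)
edited.save("edited_frame.png")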
black_box_image_edit/instantstyle.py ADDED
@@ -0,0 +1,67 @@
1
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
2
+ import cv2
3
+ import torch
4
+ import PIL
5
+ import numpy as np
6
+ import os
7
+
8
+ class InstantStyle():
9
+ def __init__(self,
10
+ device="cuda",
11
+ weight="stabilityai/stable-diffusion-xl-base-1.0",
12
+ control_weight="diffusers/controlnet-canny-sdxl-1.0",
13
+ custom_sdxl_models_folder="sdxl_models"):
14
+ from .ip_adapter import IPAdapterXL
15
+
16
+ controlnet = ControlNetModel.from_pretrained(control_weight,
17
+ use_safetensors=False,
18
+ torch_dtype=torch.float16).to(device)
19
+ # load SDXL pipeline
20
+ sdxl_control_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
21
+ weight,
22
+ controlnet=controlnet,
23
+ torch_dtype=torch.float16,
24
+ add_watermarker=False,
25
+ )
26
+ sdxl_control_pipe.enable_vae_tiling()
27
+ self.ip_model = IPAdapterXL(sdxl_control_pipe,
28
+ os.path.join(custom_sdxl_models_folder, "image_encoder"),
29
+ os.path.join(custom_sdxl_models_folder, "ip-adapter_sdxl.bin"),
30
+ device,
31
+ target_blocks=["up_blocks.0.attentions.1"])
32
+
33
+
34
+ def infer_one_image(self, src_image: PIL.Image.Image = None,
35
+ style_image: PIL.Image.Image = None,
36
+ prompt: str = "masterpiece, best quality, high quality",
37
+ seed: int = 42,
38
+ negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"):
39
+
40
+ src_image = src_image.convert('RGB') # force it to RGB format
41
+ style_image = style_image.convert('RGB') # force it to RGB format
42
+
43
+ def pil_to_cv2(image_pil):
44
+ image_np = np.array(image_pil)
45
+ image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
46
+
47
+ return image_cv2
48
+ # control image
49
+ input_image = pil_to_cv2(src_image)
50
+ detected_map = cv2.Canny(input_image, 50, 200)
51
+ canny_map = PIL.Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
52
+
53
+ # generate image
54
+ if prompt is None:
55
+ prompt = "masterpiece, best quality, high quality"
56
+ image = self.ip_model.generate(pil_image=style_image,
57
+ prompt=prompt,
58
+ negative_prompt=negative_prompt,
59
+ scale=1.0,
60
+ guidance_scale=5,
61
+ num_samples=1,
62
+ num_inference_steps=30,
63
+ seed=seed,
64
+ image=canny_map,
65
+ controlnet_conditioning_scale=0.6,
66
+ )[0]
67
+ return image
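A minimal usage sketch for this class (not part of the commit). It assumes the IP-Adapter SDXL weights have already been placed under sdxl_models/ (an image_encoder/ folder and ip-adapter_sdxl.bin, as expected by the constructor); the image paths are hypothetical:

# --- sketch: using InstantStyle for style transfer on a single frame (not part of the commit) ---
from PIL import Image
from black_box_image_edit.instantstyle import InstantStyle

styler = InstantStyle(device="cuda", custom_sdxl_models_folder="sdxl_models")
content = Image.open("first_frame.png")        # hypothetical content frame
style_ref = Image.open("style_ref.png")        # hypothetical style reference image
result = styler.infer_one_image(src_image=content, style_image=style_ref, seed=42)
result.save("styled_frame.png")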
black_box_image_edit/ip_adapter/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2
+
3
+ __all__ = [
4
+ "IPAdapter",
5
+ "IPAdapterPlus",
6
+ "IPAdapterPlusXL",
7
+ "IPAdapterXL",
8
+ "IPAdapterFull",
9
+ ]
black_box_image_edit/ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,562 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapter.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ the weight scale of image prompt.
90
+ num_tokens (`int`, defaults to 4; use 16 for the ip_adapter_plus variants):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+ self.skip = skip
102
+
103
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+
106
+ def __call__(
107
+ self,
108
+ attn,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ attention_mask=None,
112
+ temb=None,
113
+ ):
114
+ residual = hidden_states
115
+
116
+ if attn.spatial_norm is not None:
117
+ hidden_states = attn.spatial_norm(hidden_states, temb)
118
+
119
+ input_ndim = hidden_states.ndim
120
+
121
+ if input_ndim == 4:
122
+ batch_size, channel, height, width = hidden_states.shape
123
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ batch_size, sequence_length, _ = (
126
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
+ )
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ else:
138
+ # get encoder_hidden_states, ip_hidden_states
139
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
140
+ encoder_hidden_states, ip_hidden_states = (
141
+ encoder_hidden_states[:, :end_pos, :],
142
+ encoder_hidden_states[:, end_pos:, :],
143
+ )
144
+ if attn.norm_cross:
145
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
+
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+
150
+ query = attn.head_to_batch_dim(query)
151
+ key = attn.head_to_batch_dim(key)
152
+ value = attn.head_to_batch_dim(value)
153
+
154
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
+ hidden_states = torch.bmm(attention_probs, value)
156
+ hidden_states = attn.batch_to_head_dim(hidden_states)
157
+
158
+ if not self.skip:
159
+ # for ip-adapter
160
+ ip_key = self.to_k_ip(ip_hidden_states)
161
+ ip_value = self.to_v_ip(ip_hidden_states)
162
+
163
+ ip_key = attn.head_to_batch_dim(ip_key)
164
+ ip_value = attn.head_to_batch_dim(ip_value)
165
+
166
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
167
+ self.attn_map = ip_attention_probs
168
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
+
171
+ hidden_states = hidden_states + self.scale * ip_hidden_states
172
+
173
+ # linear proj
174
+ hidden_states = attn.to_out[0](hidden_states)
175
+ # dropout
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ if input_ndim == 4:
179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
+
181
+ if attn.residual_connection:
182
+ hidden_states = hidden_states + residual
183
+
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+
186
+ return hidden_states
187
+
188
+
189
+ class AttnProcessor2_0(torch.nn.Module):
190
+ r"""
191
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size=None,
197
+ cross_attention_dim=None,
198
+ ):
199
+ super().__init__()
200
+ if not hasattr(F, "scaled_dot_product_attention"):
201
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
202
+
203
+ def __call__(
204
+ self,
205
+ attn,
206
+ hidden_states,
207
+ encoder_hidden_states=None,
208
+ attention_mask=None,
209
+ temb=None,
210
+ ):
211
+ residual = hidden_states
212
+
213
+ if attn.spatial_norm is not None:
214
+ hidden_states = attn.spatial_norm(hidden_states, temb)
215
+
216
+ input_ndim = hidden_states.ndim
217
+
218
+ if input_ndim == 4:
219
+ batch_size, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
+
222
+ batch_size, sequence_length, _ = (
223
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
224
+ )
225
+
226
+ if attention_mask is not None:
227
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
228
+ # scaled_dot_product_attention expects attention_mask shape to be
229
+ # (batch, heads, source_length, target_length)
230
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
231
+
232
+ if attn.group_norm is not None:
233
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
+
242
+ key = attn.to_k(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states)
244
+
245
+ inner_dim = key.shape[-1]
246
+ head_dim = inner_dim // attn.heads
247
+
248
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
254
+ # TODO: add support for attn.scale when we move to Torch 2.1
255
+ hidden_states = F.scaled_dot_product_attention(
256
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
257
+ )
258
+
259
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
260
+ hidden_states = hidden_states.to(query.dtype)
261
+
262
+ # linear proj
263
+ hidden_states = attn.to_out[0](hidden_states)
264
+ # dropout
265
+ hidden_states = attn.to_out[1](hidden_states)
266
+
267
+ if input_ndim == 4:
268
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
269
+
270
+ if attn.residual_connection:
271
+ hidden_states = hidden_states + residual
272
+
273
+ hidden_states = hidden_states / attn.rescale_output_factor
274
+
275
+ return hidden_states
276
+
277
+
278
+ class IPAttnProcessor2_0(torch.nn.Module):
279
+ r"""
280
+ Attention processor for IP-Adapter for PyTorch 2.0.
281
+ Args:
282
+ hidden_size (`int`):
283
+ The hidden size of the attention layer.
284
+ cross_attention_dim (`int`):
285
+ The number of channels in the `encoder_hidden_states`.
286
+ scale (`float`, defaults to 1.0):
287
+ the weight scale of image prompt.
288
+ num_tokens (`int`, defaults to 4; use 16 for the ip_adapter_plus variants):
289
+ The context length of the image features.
290
+ """
291
+
292
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
293
+ super().__init__()
294
+
295
+ if not hasattr(F, "scaled_dot_product_attention"):
296
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
297
+
298
+ self.hidden_size = hidden_size
299
+ self.cross_attention_dim = cross_attention_dim
300
+ self.scale = scale
301
+ self.num_tokens = num_tokens
302
+ self.skip = skip
303
+
304
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
305
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
306
+
307
+ def __call__(
308
+ self,
309
+ attn,
310
+ hidden_states,
311
+ encoder_hidden_states=None,
312
+ attention_mask=None,
313
+ temb=None,
314
+ ):
315
+ residual = hidden_states
316
+
317
+ if attn.spatial_norm is not None:
318
+ hidden_states = attn.spatial_norm(hidden_states, temb)
319
+
320
+ input_ndim = hidden_states.ndim
321
+
322
+ if input_ndim == 4:
323
+ batch_size, channel, height, width = hidden_states.shape
324
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
325
+
326
+ batch_size, sequence_length, _ = (
327
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
328
+ )
329
+
330
+ if attention_mask is not None:
331
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
332
+ # scaled_dot_product_attention expects attention_mask shape to be
333
+ # (batch, heads, source_length, target_length)
334
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
335
+
336
+ if attn.group_norm is not None:
337
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
338
+
339
+ query = attn.to_q(hidden_states)
340
+
341
+ if encoder_hidden_states is None:
342
+ encoder_hidden_states = hidden_states
343
+ else:
344
+ # get encoder_hidden_states, ip_hidden_states
345
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
346
+ encoder_hidden_states, ip_hidden_states = (
347
+ encoder_hidden_states[:, :end_pos, :],
348
+ encoder_hidden_states[:, end_pos:, :],
349
+ )
350
+ if attn.norm_cross:
351
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
352
+
353
+ key = attn.to_k(encoder_hidden_states)
354
+ value = attn.to_v(encoder_hidden_states)
355
+
356
+ inner_dim = key.shape[-1]
357
+ head_dim = inner_dim // attn.heads
358
+
359
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
+
361
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
+
364
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
365
+ # TODO: add support for attn.scale when we move to Torch 2.1
366
+ hidden_states = F.scaled_dot_product_attention(
367
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
368
+ )
369
+
370
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
371
+ hidden_states = hidden_states.to(query.dtype)
372
+
373
+ if not self.skip:
374
+ # for ip-adapter
375
+ ip_key = self.to_k_ip(ip_hidden_states)
376
+ ip_value = self.to_v_ip(ip_hidden_states)
377
+
378
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+
381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
+ # TODO: add support for attn.scale when we move to Torch 2.1
383
+ ip_hidden_states = F.scaled_dot_product_attention(
384
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
385
+ )
386
+ with torch.no_grad():
387
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
388
+ #print(self.attn_map.shape)
389
+
390
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
391
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
392
+
393
+ hidden_states = hidden_states + self.scale * ip_hidden_states
394
+
395
+ # linear proj
396
+ hidden_states = attn.to_out[0](hidden_states)
397
+ # dropout
398
+ hidden_states = attn.to_out[1](hidden_states)
399
+
400
+ if input_ndim == 4:
401
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
402
+
403
+ if attn.residual_connection:
404
+ hidden_states = hidden_states + residual
405
+
406
+ hidden_states = hidden_states / attn.rescale_output_factor
407
+
408
+ return hidden_states
409
+
410
+
411
+ ## for controlnet
412
+ class CNAttnProcessor:
413
+ r"""
414
+ Default processor for performing attention-related computations.
415
+ """
416
+
417
+ def __init__(self, num_tokens=4):
418
+ self.num_tokens = num_tokens
419
+
420
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
421
+ residual = hidden_states
422
+
423
+ if attn.spatial_norm is not None:
424
+ hidden_states = attn.spatial_norm(hidden_states, temb)
425
+
426
+ input_ndim = hidden_states.ndim
427
+
428
+ if input_ndim == 4:
429
+ batch_size, channel, height, width = hidden_states.shape
430
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
431
+
432
+ batch_size, sequence_length, _ = (
433
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
434
+ )
435
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
436
+
437
+ if attn.group_norm is not None:
438
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
439
+
440
+ query = attn.to_q(hidden_states)
441
+
442
+ if encoder_hidden_states is None:
443
+ encoder_hidden_states = hidden_states
444
+ else:
445
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
446
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
447
+ if attn.norm_cross:
448
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
449
+
450
+ key = attn.to_k(encoder_hidden_states)
451
+ value = attn.to_v(encoder_hidden_states)
452
+
453
+ query = attn.head_to_batch_dim(query)
454
+ key = attn.head_to_batch_dim(key)
455
+ value = attn.head_to_batch_dim(value)
456
+
457
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
458
+ hidden_states = torch.bmm(attention_probs, value)
459
+ hidden_states = attn.batch_to_head_dim(hidden_states)
460
+
461
+ # linear proj
462
+ hidden_states = attn.to_out[0](hidden_states)
463
+ # dropout
464
+ hidden_states = attn.to_out[1](hidden_states)
465
+
466
+ if input_ndim == 4:
467
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
468
+
469
+ if attn.residual_connection:
470
+ hidden_states = hidden_states + residual
471
+
472
+ hidden_states = hidden_states / attn.rescale_output_factor
473
+
474
+ return hidden_states
475
+
476
+
477
+ class CNAttnProcessor2_0:
478
+ r"""
479
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
480
+ """
481
+
482
+ def __init__(self, num_tokens=4):
483
+ if not hasattr(F, "scaled_dot_product_attention"):
484
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
485
+ self.num_tokens = num_tokens
486
+
487
+ def __call__(
488
+ self,
489
+ attn,
490
+ hidden_states,
491
+ encoder_hidden_states=None,
492
+ attention_mask=None,
493
+ temb=None,
494
+ ):
495
+ residual = hidden_states
496
+
497
+ if attn.spatial_norm is not None:
498
+ hidden_states = attn.spatial_norm(hidden_states, temb)
499
+
500
+ input_ndim = hidden_states.ndim
501
+
502
+ if input_ndim == 4:
503
+ batch_size, channel, height, width = hidden_states.shape
504
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
+ )
509
+
510
+ if attention_mask is not None:
511
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
512
+ # scaled_dot_product_attention expects attention_mask shape to be
513
+ # (batch, heads, source_length, target_length)
514
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
515
+
516
+ if attn.group_norm is not None:
517
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
518
+
519
+ query = attn.to_q(hidden_states)
520
+
521
+ if encoder_hidden_states is None:
522
+ encoder_hidden_states = hidden_states
523
+ else:
524
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
525
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
526
+ if attn.norm_cross:
527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ inner_dim = key.shape[-1]
533
+ head_dim = inner_dim // attn.heads
534
+
535
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
536
+
537
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
538
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
+
540
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
541
+ # TODO: add support for attn.scale when we move to Torch 2.1
542
+ hidden_states = F.scaled_dot_product_attention(
543
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
544
+ )
545
+
546
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
547
+ hidden_states = hidden_states.to(query.dtype)
548
+
549
+ # linear proj
550
+ hidden_states = attn.to_out[0](hidden_states)
551
+ # dropout
552
+ hidden_states = attn.to_out[1](hidden_states)
553
+
554
+ if input_ndim == 4:
555
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
556
+
557
+ if attn.residual_connection:
558
+ hidden_states = hidden_states + residual
559
+
560
+ hidden_states = hidden_states / attn.rescale_output_factor
561
+
562
+ return hidden_states
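The IPAttnProcessor variants above implement decoupled cross-attention: the last num_tokens entries of encoder_hidden_states are image-prompt tokens with their own key/value projections, and their attention output is added to the text branch weighted by scale. A stripped-down sketch of that idea (illustration only; tensor shapes are arbitrary and the function below is hypothetical, not a diffusers API):

# --- sketch: decoupled text/image cross-attention (not part of the commit) ---
import torch
import torch.nn.functional as F

def decoupled_cross_attention(query, text_k, text_v, ip_k, ip_v, scale=1.0):
    # attend to text tokens and image-prompt tokens separately,
    # then add the image branch scaled by `scale`
    text_out = F.scaled_dot_product_attention(query, text_k, text_v)
    ip_out = F.scaled_dot_product_attention(query, ip_k, ip_v)
    return text_out + scale * ip_out

B, H, Lq, Lt, Li, D = 1, 8, 16, 77, 4, 64      # arbitrary shapes; Li plays the role of num_tokens
q = torch.randn(B, H, Lq, D)
out = decoupled_cross_attention(q,
                                torch.randn(B, H, Lt, D), torch.randn(B, H, Lt, D),
                                torch.randn(B, H, Li, D), torch.randn(B, H, Li, D))
print(out.shape)                               # torch.Size([1, 8, 16, 64])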
black_box_image_edit/ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,460 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from .utils import is_torch2_available, get_generator
12
+
13
+ if is_torch2_available():
14
+ from .attention_processor import (
15
+ AttnProcessor2_0 as AttnProcessor,
16
+ )
17
+ from .attention_processor import (
18
+ CNAttnProcessor2_0 as CNAttnProcessor,
19
+ )
20
+ from .attention_processor import (
21
+ IPAttnProcessor2_0 as IPAttnProcessor,
22
+ )
23
+ else:
24
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25
+ from .resampler import Resampler
26
+
27
+
28
+ class ImageProjModel(torch.nn.Module):
29
+ """Projection Model"""
30
+
31
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32
+ super().__init__()
33
+
34
+ self.generator = None
35
+ self.cross_attention_dim = cross_attention_dim
36
+ self.clip_extra_context_tokens = clip_extra_context_tokens
37
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
38
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
39
+
40
+ def forward(self, image_embeds):
41
+ embeds = image_embeds
42
+ clip_extra_context_tokens = self.proj(embeds).reshape(
43
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
44
+ )
45
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
46
+ return clip_extra_context_tokens
47
+
48
+
49
+ class MLPProjModel(torch.nn.Module):
50
+ """SD model with image prompt"""
51
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
52
+ super().__init__()
53
+
54
+ self.proj = torch.nn.Sequential(
55
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
56
+ torch.nn.GELU(),
57
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
58
+ torch.nn.LayerNorm(cross_attention_dim)
59
+ )
60
+
61
+ def forward(self, image_embeds):
62
+ clip_extra_context_tokens = self.proj(image_embeds)
63
+ return clip_extra_context_tokens
64
+
65
+
66
+ class IPAdapter:
67
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
68
+ self.device = device
69
+ self.image_encoder_path = image_encoder_path
70
+ self.ip_ckpt = ip_ckpt
71
+ self.num_tokens = num_tokens
72
+ self.target_blocks = target_blocks
73
+
74
+ self.pipe = sd_pipe.to(self.device)
75
+ self.set_ip_adapter()
76
+
77
+ # load image encoder
78
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
79
+ self.device, dtype=torch.float16
80
+ )
81
+ self.clip_image_processor = CLIPImageProcessor()
82
+ # image proj model
83
+ self.image_proj_model = self.init_proj()
84
+
85
+ self.load_ip_adapter()
86
+
87
+ def init_proj(self):
88
+ image_proj_model = ImageProjModel(
89
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
90
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
91
+ clip_extra_context_tokens=self.num_tokens,
92
+ ).to(self.device, dtype=torch.float16)
93
+ return image_proj_model
94
+
95
+ def set_ip_adapter(self):
96
+ unet = self.pipe.unet
97
+ attn_procs = {}
98
+ for name in unet.attn_processors.keys():
99
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
100
+ if name.startswith("mid_block"):
101
+ hidden_size = unet.config.block_out_channels[-1]
102
+ elif name.startswith("up_blocks"):
103
+ block_id = int(name[len("up_blocks.")])
104
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
105
+ elif name.startswith("down_blocks"):
106
+ block_id = int(name[len("down_blocks.")])
107
+ hidden_size = unet.config.block_out_channels[block_id]
108
+ if cross_attention_dim is None:
109
+ attn_procs[name] = AttnProcessor()
110
+ else:
111
+ selected = False
112
+ for block_name in self.target_blocks:
113
+ if block_name in name:
114
+ selected = True
115
+ break
116
+ if selected:
117
+ attn_procs[name] = IPAttnProcessor(
118
+ hidden_size=hidden_size,
119
+ cross_attention_dim=cross_attention_dim,
120
+ scale=1.0,
121
+ num_tokens=self.num_tokens,
122
+ ).to(self.device, dtype=torch.float16)
123
+ else:
124
+ attn_procs[name] = IPAttnProcessor(
125
+ hidden_size=hidden_size,
126
+ cross_attention_dim=cross_attention_dim,
127
+ scale=1.0,
128
+ num_tokens=self.num_tokens,
129
+ skip=True
130
+ ).to(self.device, dtype=torch.float16)
131
+ unet.set_attn_processor(attn_procs)
132
+ if hasattr(self.pipe, "controlnet"):
133
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
134
+ for controlnet in self.pipe.controlnet.nets:
135
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
136
+ else:
137
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
138
+
139
+ def load_ip_adapter(self):
140
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
141
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
142
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
143
+ for key in f.keys():
144
+ if key.startswith("image_proj."):
145
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
146
+ elif key.startswith("ip_adapter."):
147
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
148
+ else:
149
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
150
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
151
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
152
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
153
+
154
+ @torch.inference_mode()
155
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
156
+ if pil_image is not None:
157
+ if isinstance(pil_image, Image.Image):
158
+ pil_image = [pil_image]
159
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
160
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
161
+ else:
162
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
163
+
164
+ if content_prompt_embeds is not None:
165
+ clip_image_embeds = clip_image_embeds - content_prompt_embeds
166
+
167
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
168
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
169
+ return image_prompt_embeds, uncond_image_prompt_embeds
170
+
171
+ def set_scale(self, scale):
172
+ for attn_processor in self.pipe.unet.attn_processors.values():
173
+ if isinstance(attn_processor, IPAttnProcessor):
174
+ attn_processor.scale = scale
175
+
176
+ def generate(
177
+ self,
178
+ pil_image=None,
179
+ clip_image_embeds=None,
180
+ prompt=None,
181
+ negative_prompt=None,
182
+ scale=1.0,
183
+ num_samples=4,
184
+ seed=None,
185
+ guidance_scale=7.5,
186
+ num_inference_steps=30,
187
+ neg_content_emb=None,
188
+ **kwargs,
189
+ ):
190
+ self.set_scale(scale)
191
+
192
+ if pil_image is not None:
193
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
194
+ else:
195
+ num_prompts = clip_image_embeds.size(0)
196
+
197
+ if prompt is None:
198
+ prompt = "best quality, high quality"
199
+ if negative_prompt is None:
200
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
201
+
202
+ if not isinstance(prompt, List):
203
+ prompt = [prompt] * num_prompts
204
+ if not isinstance(negative_prompt, List):
205
+ negative_prompt = [negative_prompt] * num_prompts
206
+
207
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
208
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
209
+ )
210
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
211
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
212
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
213
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
214
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
215
+
216
+ with torch.inference_mode():
217
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
218
+ prompt,
219
+ device=self.device,
220
+ num_images_per_prompt=num_samples,
221
+ do_classifier_free_guidance=True,
222
+ negative_prompt=negative_prompt,
223
+ )
224
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
225
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
226
+
227
+ generator = get_generator(seed, self.device)
228
+
229
+ images = self.pipe(
230
+ prompt_embeds=prompt_embeds,
231
+ negative_prompt_embeds=negative_prompt_embeds,
232
+ guidance_scale=guidance_scale,
233
+ num_inference_steps=num_inference_steps,
234
+ generator=generator,
235
+ **kwargs,
236
+ ).images
237
+
238
+ return images
239
+
240
+
241
+ class IPAdapterXL(IPAdapter):
242
+ """SDXL"""
243
+
244
+ def generate(
245
+ self,
246
+ pil_image,
247
+ prompt=None,
248
+ negative_prompt=None,
249
+ scale=1.0,
250
+ num_samples=4,
251
+ seed=None,
252
+ num_inference_steps=30,
253
+ neg_content_emb=None,
254
+ neg_content_prompt=None,
255
+ neg_content_scale=1.0,
256
+ **kwargs,
257
+ ):
258
+ self.set_scale(scale)
259
+
260
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
261
+
262
+ if prompt is None:
263
+ prompt = "best quality, high quality"
264
+ if negative_prompt is None:
265
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
266
+
267
+ if not isinstance(prompt, List):
268
+ prompt = [prompt] * num_prompts
269
+ if not isinstance(negative_prompt, List):
270
+ negative_prompt = [negative_prompt] * num_prompts
271
+
272
+ if neg_content_emb is None:
273
+ if neg_content_prompt is not None:
274
+ with torch.inference_mode():
275
+ (
276
+ prompt_embeds_, # torch.Size([1, 77, 2048])
277
+ negative_prompt_embeds_,
278
+ pooled_prompt_embeds_, # torch.Size([1, 1280])
279
+ negative_pooled_prompt_embeds_,
280
+ ) = self.pipe.encode_prompt(
281
+ neg_content_prompt,
282
+ num_images_per_prompt=num_samples,
283
+ do_classifier_free_guidance=True,
284
+ negative_prompt=negative_prompt,
285
+ )
286
+ pooled_prompt_embeds_ *= neg_content_scale
287
+ else:
288
+ pooled_prompt_embeds_ = neg_content_emb
289
+ else:
290
+ pooled_prompt_embeds_ = None
291
+
292
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
293
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
294
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
295
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
296
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
297
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
298
+
299
+ with torch.inference_mode():
300
+ (
301
+ prompt_embeds,
302
+ negative_prompt_embeds,
303
+ pooled_prompt_embeds,
304
+ negative_pooled_prompt_embeds,
305
+ ) = self.pipe.encode_prompt(
306
+ prompt,
307
+ num_images_per_prompt=num_samples,
308
+ do_classifier_free_guidance=True,
309
+ negative_prompt=negative_prompt,
310
+ )
311
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
312
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
313
+
314
+ self.generator = get_generator(seed, self.device)
315
+
316
+ images = self.pipe(
317
+ prompt_embeds=prompt_embeds,
318
+ negative_prompt_embeds=negative_prompt_embeds,
319
+ pooled_prompt_embeds=pooled_prompt_embeds,
320
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
321
+ num_inference_steps=num_inference_steps,
322
+ generator=self.generator,
323
+ **kwargs,
324
+ ).images
325
+
326
+ return images
327
+
328
+
329
+ class IPAdapterPlus(IPAdapter):
330
+ """IP-Adapter with fine-grained features"""
331
+
332
+ def init_proj(self):
333
+ image_proj_model = Resampler(
334
+ dim=self.pipe.unet.config.cross_attention_dim,
335
+ depth=4,
336
+ dim_head=64,
337
+ heads=12,
338
+ num_queries=self.num_tokens,
339
+ embedding_dim=self.image_encoder.config.hidden_size,
340
+ output_dim=self.pipe.unet.config.cross_attention_dim,
341
+ ff_mult=4,
342
+ ).to(self.device, dtype=torch.float16)
343
+ return image_proj_model
344
+
345
+ @torch.inference_mode()
346
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
347
+ if isinstance(pil_image, Image.Image):
348
+ pil_image = [pil_image]
349
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
350
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
351
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
352
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
353
+ uncond_clip_image_embeds = self.image_encoder(
354
+ torch.zeros_like(clip_image), output_hidden_states=True
355
+ ).hidden_states[-2]
356
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
357
+ return image_prompt_embeds, uncond_image_prompt_embeds
358
+
359
+
360
+ class IPAdapterFull(IPAdapterPlus):
361
+ """IP-Adapter with full features"""
362
+
363
+ def init_proj(self):
364
+ image_proj_model = MLPProjModel(
365
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
366
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
367
+ ).to(self.device, dtype=torch.float16)
368
+ return image_proj_model
369
+
370
+
371
+ class IPAdapterPlusXL(IPAdapter):
372
+ """SDXL"""
373
+
374
+ def init_proj(self):
375
+ image_proj_model = Resampler(
376
+ dim=1280,
377
+ depth=4,
378
+ dim_head=64,
379
+ heads=20,
380
+ num_queries=self.num_tokens,
381
+ embedding_dim=self.image_encoder.config.hidden_size,
382
+ output_dim=self.pipe.unet.config.cross_attention_dim,
383
+ ff_mult=4,
384
+ ).to(self.device, dtype=torch.float16)
385
+ return image_proj_model
386
+
387
+ @torch.inference_mode()
388
+ def get_image_embeds(self, pil_image):
389
+ if isinstance(pil_image, Image.Image):
390
+ pil_image = [pil_image]
391
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
392
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
393
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
394
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
395
+ uncond_clip_image_embeds = self.image_encoder(
396
+ torch.zeros_like(clip_image), output_hidden_states=True
397
+ ).hidden_states[-2]
398
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
399
+ return image_prompt_embeds, uncond_image_prompt_embeds
400
+
401
+ def generate(
402
+ self,
403
+ pil_image,
404
+ prompt=None,
405
+ negative_prompt=None,
406
+ scale=1.0,
407
+ num_samples=4,
408
+ seed=None,
409
+ num_inference_steps=30,
410
+ **kwargs,
411
+ ):
412
+ self.set_scale(scale)
413
+
414
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
415
+
416
+ if prompt is None:
417
+ prompt = "best quality, high quality"
418
+ if negative_prompt is None:
419
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
420
+
421
+ if not isinstance(prompt, List):
422
+ prompt = [prompt] * num_prompts
423
+ if not isinstance(negative_prompt, List):
424
+ negative_prompt = [negative_prompt] * num_prompts
425
+
426
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
427
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
428
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
429
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
430
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
431
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
432
+
433
+ with torch.inference_mode():
434
+ (
435
+ prompt_embeds,
436
+ negative_prompt_embeds,
437
+ pooled_prompt_embeds,
438
+ negative_pooled_prompt_embeds,
439
+ ) = self.pipe.encode_prompt(
440
+ prompt,
441
+ num_images_per_prompt=num_samples,
442
+ do_classifier_free_guidance=True,
443
+ negative_prompt=negative_prompt,
444
+ )
445
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
446
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
447
+
448
+ generator = get_generator(seed, self.device)
449
+
450
+ images = self.pipe(
451
+ prompt_embeds=prompt_embeds,
452
+ negative_prompt_embeds=negative_prompt_embeds,
453
+ pooled_prompt_embeds=pooled_prompt_embeds,
454
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
455
+ num_inference_steps=num_inference_steps,
456
+ generator=generator,
457
+ **kwargs,
458
+ ).images
459
+
460
+ return images
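The target_blocks argument is what restricts image-prompt injection to specific UNet blocks (all other cross-attention layers get an IPAttnProcessor with skip=True), which is how instantstyle.py achieves style-only conditioning. A minimal construction sketch (not part of the commit), assuming the same local weight layout as instantstyle.py and the public SDXL base checkpoint:

# --- sketch: building an SDXL IP-Adapter with style-only injection (not part of the commit) ---
import os
import torch
from diffusers import StableDiffusionXLPipeline
from black_box_image_edit.ip_adapter import IPAdapterXL

sdxl_folder = "sdxl_models"                    # assumed to contain image_encoder/ and ip-adapter_sdxl.bin
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    add_watermarker=False,
)
ip_model = IPAdapterXL(
    pipe,
    os.path.join(sdxl_folder, "image_encoder"),
    os.path.join(sdxl_folder, "ip-adapter_sdxl.bin"),
    "cuda",
    target_blocks=["up_blocks.0.attentions.1"],  # inject the image prompt only in this block
)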
black_box_image_edit/ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
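The Resampler compresses a variable-length sequence of CLIP image-encoder hidden states into a fixed set of num_queries tokens at the UNet's cross-attention width. A shape-only sketch (not part of the commit; the embedding_dim of 1664 assumes an OpenCLIP ViT-bigG image encoder and is only illustrative):

# --- sketch: Resampler shape flow (not part of the commit) ---
import torch
from black_box_image_edit.ip_adapter.resampler import Resampler

resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
                      num_queries=16, embedding_dim=1664, output_dim=2048, ff_mult=4)
clip_hidden = torch.randn(1, 257, 1664)        # penultimate CLIP hidden states for one image
tokens = resampler(clip_hidden)
print(tokens.shape)                            # torch.Size([1, 16, 2048])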
black_box_image_edit/ip_adapter/utils.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ attn_maps = {}
7
+ def hook_fn(name):
8
+ def forward_hook(module, input, output):
9
+ if hasattr(module.processor, "attn_map"):
10
+ attn_maps[name] = module.processor.attn_map
11
+ del module.processor.attn_map
12
+
13
+ return forward_hook
14
+
15
+ def register_cross_attention_hook(unet):
16
+ for name, module in unet.named_modules():
17
+ if name.split('.')[-1].startswith('attn2'):
18
+ module.register_forward_hook(hook_fn(name))
19
+
20
+ return unet
21
+
22
+ def upscale(attn_map, target_size):
23
+ attn_map = torch.mean(attn_map, dim=0)
24
+ attn_map = attn_map.permute(1,0)
25
+ temp_size = None
26
+
27
+ for i in range(0,5):
28
+ scale = 2 ** i
29
+ if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30
+ temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31
+ break
32
+
33
+ assert temp_size is not None, "temp_size cannot be None"
34
+
35
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36
+
37
+ attn_map = F.interpolate(
38
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
39
+ size=target_size,
40
+ mode='bilinear',
41
+ align_corners=False
42
+ )[0]
43
+
44
+ attn_map = torch.softmax(attn_map, dim=0)
45
+ return attn_map
46
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47
+
48
+ idx = 0 if instance_or_negative else 1
49
+ net_attn_maps = []
50
+
51
+ for name, attn_map in attn_maps.items():
52
+ attn_map = attn_map.cpu() if detach else attn_map
53
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54
+ attn_map = upscale(attn_map, image_size)
55
+ net_attn_maps.append(attn_map)
56
+
57
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58
+
59
+ return net_attn_maps
60
+
61
+ def attnmaps2images(net_attn_maps):
62
+
63
+ #total_attn_scores = 0
64
+ images = []
65
+
66
+ for attn_map in net_attn_maps:
67
+ attn_map = attn_map.cpu().numpy()
68
+ #total_attn_scores += attn_map.mean().item()
69
+
70
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
72
+ #print("norm: ", normalized_attn_map.shape)
73
+ image = Image.fromarray(normalized_attn_map)
74
+
75
+ #image = fix_save_attn_map(attn_map)
76
+ images.append(image)
77
+
78
+ #print(total_attn_scores)
79
+ return images
80
+ def is_torch2_available():
81
+ return hasattr(F, "scaled_dot_product_attention")
82
+
83
+ def get_generator(seed, device):
84
+
85
+ if seed is not None:
86
+ if isinstance(seed, list):
87
+ generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88
+ else:
89
+ generator = torch.Generator(device).manual_seed(seed)
90
+ else:
91
+ generator = None
92
+
93
+ return generator
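A small usage sketch for the helpers above (not part of the commit): is_torch2_available gates the SDPA attention processors, and get_generator accepts a single seed, a list of seeds, or None:

# --- sketch: seeding helpers (not part of the commit) ---
from black_box_image_edit.ip_adapter.utils import get_generator, is_torch2_available

print(is_torch2_available())                # True when F.scaled_dot_product_attention exists (PyTorch >= 2.0)
g_single = get_generator(42, "cuda")        # one torch.Generator seeded with 42
g_batch = get_generator([1, 2, 3], "cuda")  # a list of generators, one per seed
g_none = get_generator(None, "cuda")        # None, so the pipeline falls back to its default RNG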
black_box_image_edit/utils.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ from moviepy.editor import VideoFileClip
3
+ import random
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ def crop_and_resize_video(input_video_path, output_folder, clip_duration=None, width=None, height=None, start_time=None, end_time=None, n_frames=16, center_crop=False, x_offset=0, y_offset=0, longest_to_width=False):
+ # Load the video file
8
+ video = VideoFileClip(input_video_path)
9
+
10
+ # Calculate start and end times for cropping
11
+ if clip_duration is not None:
12
+ if start_time is not None:
13
+ start_time = float(start_time)
14
+ end_time = start_time + clip_duration
15
+ elif end_time is not None:
16
+ end_time = float(end_time)
17
+ start_time = end_time - clip_duration
18
+ else:
19
+ # Default to random cropping if neither start nor end time is specified
20
+ video_duration = video.duration
21
+ if video_duration <= clip_duration:
22
+ print(f"Skipping {input_video_path}: duration is less than or equal to the clip duration.")
23
+ return
24
+ max_start_time = video_duration - clip_duration
25
+ start_time = random.uniform(0, max_start_time)
26
+ end_time = start_time + clip_duration
27
+ elif start_time is not None and end_time is not None:
28
+ start_time = float(start_time)
29
+ end_time = float(end_time)
30
+ clip_duration = int(end_time - start_time)
31
+ else:
32
+ raise ValueError("Either clip_duration must be provided, or both start_time and end_time must be specified.")
33
+
34
+ # Crop the video
35
+ cropped_video = video.subclip(start_time, end_time)
36
+
37
+
38
+ if center_crop:
39
+ # Calculate scale to ensure the desired crop size fits within the video
40
+ video_width, video_height = cropped_video.size
41
+ scale_width = video_width / width
42
+ scale_height = video_height / height
43
+ if longest_to_width:
44
+ scale = max(scale_width, scale_height)
45
+ else:
46
+ scale = min(scale_width, scale_height)
47
+
48
+ # Resize video to ensure the crop area fits within the frame
49
+ # This step ensures that the smallest dimension matches or exceeds the target crop size
50
+ new_width = int(video_width / scale)
51
+ new_height = int(video_height / scale)
52
+ resized_video = cropped_video.resize(newsize=(new_width, new_height))
53
+ print(f"Resized video to ({new_width}, {new_height})")
54
+
55
+ # Calculate crop position with offset, ensuring the crop does not go out of bounds
56
+ # The offset calculation needs to ensure that the cropping area remains within the video frame
57
+ offset_x = int(((x_offset + 1) / 2) * (new_width - width)) # Adjusted for [-1, 1] scale
58
+ offset_y = int(((y_offset + 1) / 2) * (new_height - height)) # Adjusted for [-1, 1] scale
59
+
60
+ # Ensure offsets do not push the crop area out of the video frame
61
+ offset_x = max(0, min(new_width - width, offset_x))
62
+ offset_y = max(0, min(new_height - height, offset_y))
63
+
64
+ # Apply center crop with offsets
65
+ cropped_video = resized_video.crop(x1=offset_x, y1=offset_y, width=width, height=height)
66
+ elif width and height:
67
+ # Directly resize the video to specified width and height if no center crop is specified
68
+ cropped_video = cropped_video.resize(newsize=(width, height))
69
+
70
+
71
+ # After resizing and cropping, set the frame rate to fps
72
+ fps = n_frames // clip_duration
73
+ final_video = cropped_video.set_fps(fps)
74
+
75
+ # Prepare the output video path
76
+ if not os.path.exists(output_folder):
77
+ os.makedirs(output_folder)
78
+ filename = os.path.basename(input_video_path)
79
+ output_video_path = os.path.join(output_folder, filename)
80
+
81
+ # Write the result to the output file
82
+ final_video.write_videofile(output_video_path, codec='libx264', audio_codec='aac', fps=fps)
83
+ print(f"Processed {input_video_path}, saved to {output_video_path}")
84
+ return output_video_path
85
+
86
+
87
+ def infer_video_prompt(model, video_path, output_dir, prompt, prompt_type="instruct", force_512=False, seed=42, negative_prompt="", overwrite=False):
88
+ """
89
+ Processes the first frame of the given video, optionally resizing it to 512x512 before feeding it into the model,
90
+ and saves the edited frame back at its original size in the output directory.
91
+
92
+ Args:
93
+ model: The video editing model.
94
+ video_path (str): Path to the input video.
95
+ output_dir (str): Path to the directory where processed videos will be saved.
96
+ prompt (str): Instruction prompt for video editing.
97
+ """
98
+
99
+ # Create the output directory if it does not exist
100
+ if not os.path.exists(output_dir):
101
+ os.makedirs(output_dir)
102
+
103
+ video_clip = VideoFileClip(video_path)
104
+ video_filename = os.path.basename(video_path)
105
+ # filename_noext = os.path.splitext(video_filename)[0]
106
+
107
+ # Create the output directory if it does not exist
108
+ # final_output_dir = os.path.join(output_dir, filename_noext)
109
+ final_output_dir = output_dir
110
+ if not os.path.exists(final_output_dir):
111
+ os.makedirs(final_output_dir)
112
+
113
+ result_path = os.path.join(final_output_dir, prompt + ".png")
114
+
115
+ # Check if result already exists
116
+ if os.path.exists(result_path) and overwrite is False:
117
+ print(f"Result already exists: {result_path}")
118
+ return
119
+
120
+ def process_frame(image):
121
+ pil_image = Image.fromarray(image)
122
+ if force_512:
123
+ pil_image = pil_image.resize((512, 512), Image.LANCZOS)
124
+ if prompt_type == "instruct":
125
+ result = model.infer_one_image(pil_image, instruct_prompt=prompt, seed=seed, negative_prompt=negative_prompt)
126
+ else:
127
+ result = model.infer_one_image(pil_image, target_prompt=prompt, seed=seed, negative_prompt=negative_prompt)
128
+ if force_512:
129
+ result = result.resize(video_clip.size, Image.LANCZOS)
130
+ return np.array(result)
131
+
132
+ # Process only the first frame
133
+ first_frame = video_clip.get_frame(0) # Get the first frame
134
+ processed_frame = process_frame(first_frame) # Process the first frame
135
+
136
+
137
+ #Image.fromarray(first_frame).save(os.path.join(final_output_dir, "00000.png"))
138
+ Image.fromarray(processed_frame).save(result_path)
139
+ print(f"Processed and saved the first frame: {result_path}")
140
+ return result_path
141
+
142
+ def infer_video_style(model, video_path, output_dir, style_image, prompt, force_512=False, seed=42, negative_prompt="", overwrite=False):
143
+ if not os.path.exists(output_dir):
144
+ os.makedirs(output_dir)
145
+
146
+ video_clip = VideoFileClip(video_path)
147
+ video_filename = os.path.basename(video_path)
148
+ final_output_dir = output_dir
149
+ if not os.path.exists(final_output_dir):
150
+ os.makedirs(final_output_dir)
151
+
152
+ result_path = os.path.join(final_output_dir, "style" + ".png")
153
+ if os.path.exists(result_path) and overwrite is False:
154
+ print(f"Result already exists: {result_path}")
155
+ return
156
+ def process_frame(image):
157
+ pil_image = Image.fromarray(image)
158
+ if force_512:
159
+ pil_image = pil_image.resize((512, 512), Image.LANCZOS)
160
+ result = model.infer_one_image(pil_image,
161
+ style_image=style_image,
162
+ prompt=prompt,
163
+ seed=seed,
164
+ negative_prompt=negative_prompt)
165
+ if force_512:
166
+ result = result.resize(video_clip.size, Image.LANCZOS)
167
+ return np.array(result)
168
+ # Process only the first frame
169
+ first_frame = video_clip.get_frame(0) # Get the first frame
170
+ processed_frame = process_frame(first_frame) # Process the first frame
171
+ Image.fromarray(processed_frame).save(result_path)
172
+ print(f"Processed and saved the first frame: {result_path}")
173
+ return result_path
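Putting the helpers in this file together with the editor above, a minimal end-to-end sketch (not part of the commit; file paths and the prompt are hypothetical):

# --- sketch: crop a clip and edit its first frame (not part of the commit) ---
from black_box_image_edit.cosxl_edit import CosXLEdit
from black_box_image_edit.utils import crop_and_resize_video, infer_video_prompt

# crop a 2-second clip at 512x512; n_frames=16 over 2 s gives 8 fps
processed_video = crop_and_resize_video("input.mp4", "_demo_temp",
                                        clip_duration=2, width=512, height=512,
                                        start_time=0, n_frames=16, center_crop=True)

model = CosXLEdit(device="cuda")
edited_frame = infer_video_prompt(model, processed_video, "_demo_temp/edited",
                                  prompt="make it look like a watercolor painting",
                                  prompt_type="instruct", force_512=True, seed=42)
print(edited_frame)                    # path to the edited first frame (<prompt>.png)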