zhiweili committed
Commit 8219169
Parent: 1ff5892

add app_haircolor_inpainting

app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from app_haircolor import create_demo as create_demo_haircolor
+from app_haircolor_inpainting import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
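The hunk ends at `with gr.Tabs():`, so the tab that actually mounts `create_demo_haircolor` is not visible in this diff. For orientation only, a hypothetical sketch of that wiring is shown below; the tab label and the `launch()` call are assumptions, not lines from app.py.

```py
# Hypothetical continuation of app.py (not part of this diff): mounting the demo as a tab.
import gradio as gr
from app_haircolor_inpainting import create_demo as create_demo_haircolor

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab(label="Hair Color"):   # tab label is an assumption
            create_demo_haircolor()        # renders the nested Blocks defined in app_haircolor_inpainting.py

demo.launch()
```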
app_haircolor.py CHANGED
@@ -19,7 +19,7 @@ from controlnet_aux import (
     CannyDetector,
 )
 
-BASE_MODEL = "SG161222/RealVisXL_V5.0_Lightning"
+BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
app_haircolor_inpainting.py ADDED
@@ -0,0 +1,160 @@
+import spaces
+import gradio as gr
+import time
+import torch
+
+from PIL import Image
+from segment_utils import (
+    segment_image_withmask,
+    restore_result,
+)
+from diffusers import (
+    DiffusionPipeline,
+    T2IAdapter,
+    MultiAdapter,
+)
+
+from controlnet_aux import (
+    LineartDetector,
+    CannyDetector,
+)
+
+BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
+DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, poorly drawn face, bad face, fused face, ugly face, worst face, asymmetrical, unrealistic skin texture, bad proportions, out of frame, poorly drawn hands, cloned face, double face"
+
+DEFAULT_CATEGORY = "hair"
+
+lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
+lineart_detector = lineart_detector.to(DEVICE)
+
+canny_detector = CannyDetector()
+
+adapters = MultiAdapter(
+    [
+        T2IAdapter.from_pretrained(
+            "TencentARC/t2i-adapter-lineart-sdxl-1.0",
+            torch_dtype=torch.float16,
+            variant="fp16",
+        ),
+        T2IAdapter.from_pretrained(
+            "TencentARC/t2i-adapter-canny-sdxl-1.0",
+            torch_dtype=torch.float16,
+            variant="fp16",
+        ),
+    ]
+)
+adapters = adapters.to(torch.float16)
+
+basepipeline = DiffusionPipeline.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    adapter=adapters,
+    custom_pipeline="./pipelines/pipeline_sdxl_adapter_inpaint.py",  # pipeline file added in this commit
+)
+
+basepipeline = basepipeline.to(DEVICE)
+
+basepipeline.enable_model_cpu_offload()
+
+@spaces.GPU(duration=30)
+def image_to_image(
+    input_image: Image,
+    mask_image: Image,
+    edit_prompt: str,
+    seed: int,
+    num_steps: int,
+    guidance_scale: float,
+    generate_size: int,
+    lineart_scale: float = 1.0,
+    canny_scale: float = 0.5,
+):
+    run_task_time = 0
+    time_cost_str = ''
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    lineart_image = lineart_detector(input_image, 384, generate_size)
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+    canny_image = canny_detector(input_image, 384, generate_size)
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+    cond_image = [lineart_image, canny_image]
+    cond_scale = [lineart_scale, canny_scale]
+
+    generator = torch.Generator(device=DEVICE).manual_seed(seed)
+    generated_image = basepipeline(
+        generator=generator,
+        prompt=edit_prompt,
+        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
+        image=input_image,
+        mask_image=mask_image,
+        height=generate_size,
+        width=generate_size,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_steps,
+        adapter_image=cond_image,
+        adapter_conditioning_scale=cond_scale,
+    ).images[0]
+
+    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+    return generated_image, time_cost_str
+
+def get_time_cost(run_task_time, time_cost_str):
+    now_time = int(time.time() * 1000)
+    if run_task_time == 0:
+        time_cost_str = 'start'
+    else:
+        if time_cost_str != '':
+            time_cost_str += '-->'
+        time_cost_str += f'{now_time - run_task_time}'
+    run_task_time = now_time
+    return run_task_time, time_cost_str
+
+def create_demo() -> gr.Blocks:
+    with gr.Blocks() as demo:
+        croper = gr.State()
+        with gr.Row():
+            with gr.Column():
+                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                generate_size = gr.Number(label="Generate Size", value=1024)
+                seed = gr.Number(label="Seed", value=8)
+                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+            with gr.Column():
+                num_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Num Steps")
+                guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
+                mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
+            with gr.Column():
+                mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+                lineart_scale = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Lineart Scale")
+                canny_scale = gr.Slider(minimum=0, maximum=2, value=0.5, step=0.1, label="Canny Scale")
+                g_btn = gr.Button("Edit Image")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column():
+                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+            with gr.Column():
+                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+                mask_image = gr.Image(label="Mask Image", type="pil", interactive=False)
+
+        g_btn.click(
+            fn=segment_image_withmask,
+            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+            outputs=[origin_area_image, mask_image, croper],
+        ).success(
+            fn=image_to_image,
+            inputs=[origin_area_image, mask_image, edit_prompt, seed, num_steps, guidance_scale, generate_size, lineart_scale, canny_scale],
+            outputs=[generated_image, generated_cost],
+        ).success(
+            fn=restore_result,
+            inputs=[croper, category, generated_image],
+            outputs=[restored_image],
+        )
+
+    return demo
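As an editorial aside (not part of the commit), the sketch below drives the same adapter-inpainting call path without the Gradio UI, mirroring `image_to_image` above. The image and mask file paths are placeholders, and it assumes the module can be imported outside a ZeroGPU Space (the `spaces` package must still be installed) and that a CUDA device is available.

```py
# Hedged sketch: call the custom adapter-inpainting pipeline directly, mirroring image_to_image().
# "face.png" and "hair_mask.png" are hypothetical inputs; in the app the mask comes from segment_utils.
import torch
from PIL import Image

from app_haircolor_inpainting import (
    DEFAULT_NEGATIVE_PROMPT,
    basepipeline,
    canny_detector,
    lineart_detector,
)

size = 1024
face = Image.open("face.png").convert("RGB").resize((size, size))
hair_mask = Image.open("hair_mask.png").convert("L").resize((size, size))

# Control images for the two T2I-Adapters, using the same detector arguments as image_to_image.
lineart = lineart_detector(face, 384, size)
canny = canny_detector(face, 384, size)

result = basepipeline(
    prompt="a woman, blue hair, high detailed",
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    image=face,
    mask_image=hair_mask,
    height=size,
    width=size,
    guidance_scale=5.0,
    num_inference_steps=30,
    generator=torch.Generator(device="cuda").manual_seed(8),
    adapter_image=[lineart, canny],
    adapter_conditioning_scale=[1.0, 0.5],
).images[0]
result.save("blue_hair.png")
```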
pipelines/pipeline_sdxl_adapter_inpaint.py ADDED
@@ -0,0 +1,1834 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTextModelWithProjection,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ )
28
+
29
+ from diffusers.callbacks import (
30
+ MultiPipelineCallbacks,
31
+ PipelineCallback,
32
+ )
33
+
34
+ from diffusers.image_processor import (
35
+ PipelineImageInput,
36
+ VaeImageProcessor,
37
+ )
38
+
39
+ from diffusers.loaders import (
40
+ FromSingleFileMixin,
41
+ IPAdapterMixin,
42
+ StableDiffusionXLLoraLoaderMixin,
43
+ TextualInversionLoaderMixin,
44
+ )
45
+
46
+ from diffusers.models import (
47
+ AutoencoderKL,
48
+ ImageProjection,
49
+ MultiAdapter,
50
+ T2IAdapter,
51
+ UNet2DConditionModel,
52
+ )
53
+
54
+ from diffusers.models.attention_processor import (
55
+ AttnProcessor2_0,
56
+ XFormersAttnProcessor,
57
+ )
58
+
59
+ from diffusers.models.lora import (
60
+ adjust_lora_scale_text_encoder,
61
+ )
62
+
63
+ from diffusers.schedulers import (
64
+ KarrasDiffusionSchedulers,
65
+ )
66
+
67
+ from diffusers.utils import (
68
+ PIL_INTERPOLATION,
69
+ USE_PEFT_BACKEND,
70
+ deprecate,
71
+ is_invisible_watermark_available,
72
+ is_torch_xla_available,
73
+ logging,
74
+ replace_example_docstring,
75
+ scale_lora_layers,
76
+ unscale_lora_layers,
77
+ )
78
+
79
+ from diffusers.utils.torch_utils import (
80
+ randn_tensor,
81
+ )
82
+
83
+ from diffusers.pipelines.pipeline_utils import (
84
+ DiffusionPipeline,
85
+ StableDiffusionMixin,
86
+ )
87
+
88
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
89
+ StableDiffusionXLPipelineOutput,
90
+ )
91
+
92
+ if is_invisible_watermark_available():
93
+ from diffusers.pipelines.stable_diffusion_xl.watermark import (
94
+ StableDiffusionXLWatermarker,
95
+ )
96
+
97
+ if is_torch_xla_available():
98
+ import torch_xla.core.xla_model as xm
99
+
100
+ XLA_AVAILABLE = True
101
+ else:
102
+ XLA_AVAILABLE = False
103
+
104
+
105
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
106
+
107
+
108
+ EXAMPLE_DOC_STRING = """
109
+ Examples:
110
+ ```py
111
+ >>> import torch
112
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
113
+ >>> from diffusers.utils import load_image
114
+
115
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
116
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
117
+ ... torch_dtype=torch.float16,
118
+ ... variant="fp16",
119
+ ... use_safetensors=True,
120
+ ... )
121
+ >>> pipe.to("cuda")
122
+
123
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
124
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
125
+
126
+ >>> init_image = load_image(img_url).convert("RGB")
127
+ >>> mask_image = load_image(mask_url).convert("RGB")
128
+
129
+ >>> prompt = "A majestic tiger sitting on a bench"
130
+ >>> image = pipe(
131
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
132
+ ... ).images[0]
133
+ ```
134
+ """
135
+
136
+
137
+ def _preprocess_adapter_image(image, height, width):
138
+ if isinstance(image, torch.Tensor):
139
+ return image
140
+ elif isinstance(image, PIL.Image.Image):
141
+ image = [image]
142
+
143
+ if isinstance(image[0], PIL.Image.Image):
144
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
145
+ image = [
146
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
147
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
148
+ image = np.concatenate(image, axis=0)
149
+ image = np.array(image).astype(np.float32) / 255.0
150
+ image = image.transpose(0, 3, 1, 2)
151
+ image = torch.from_numpy(image)
152
+ elif isinstance(image[0], torch.Tensor):
153
+ if image[0].ndim == 3:
154
+ image = torch.stack(image, dim=0)
155
+ elif image[0].ndim == 4:
156
+ image = torch.cat(image, dim=0)
157
+ else:
158
+ raise ValueError(
159
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
160
+ )
161
+ return image
162
+
163
+
164
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
165
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
166
+ """
167
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
168
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
169
+ """
170
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
171
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
172
+ # rescale the results from guidance (fixes overexposure)
173
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
174
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
175
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
176
+ return noise_cfg
177
+
178
+
179
+ def mask_pil_to_torch(mask, height, width):
180
+ # preprocess mask
181
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
182
+ mask = [mask]
183
+
184
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
185
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
186
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
187
+ mask = mask.astype(np.float32) / 255.0
188
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
189
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
190
+
191
+ mask = torch.from_numpy(mask)
192
+ return mask
193
+
194
+
195
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
196
+ def retrieve_latents(
197
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
198
+ ):
199
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
200
+ return encoder_output.latent_dist.sample(generator)
201
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
202
+ return encoder_output.latent_dist.mode()
203
+ elif hasattr(encoder_output, "latents"):
204
+ return encoder_output.latents
205
+ else:
206
+ raise AttributeError("Could not access latents of provided encoder_output")
207
+
208
+
209
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
210
+ def retrieve_timesteps(
211
+ scheduler,
212
+ num_inference_steps: Optional[int] = None,
213
+ device: Optional[Union[str, torch.device]] = None,
214
+ timesteps: Optional[List[int]] = None,
215
+ sigmas: Optional[List[float]] = None,
216
+ **kwargs,
217
+ ):
218
+ """
219
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
220
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
221
+
222
+ Args:
223
+ scheduler (`SchedulerMixin`):
224
+ The scheduler to get timesteps from.
225
+ num_inference_steps (`int`):
226
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
227
+ must be `None`.
228
+ device (`str` or `torch.device`, *optional*):
229
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
230
+ timesteps (`List[int]`, *optional*):
231
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
232
+ `num_inference_steps` and `sigmas` must be `None`.
233
+ sigmas (`List[float]`, *optional*):
234
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
235
+ `num_inference_steps` and `timesteps` must be `None`.
236
+
237
+ Returns:
238
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
239
+ second element is the number of inference steps.
240
+ """
241
+ if timesteps is not None and sigmas is not None:
242
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
243
+ if timesteps is not None:
244
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
245
+ if not accepts_timesteps:
246
+ raise ValueError(
247
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
248
+ f" timestep schedules. Please check whether you are using the correct scheduler."
249
+ )
250
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
251
+ timesteps = scheduler.timesteps
252
+ num_inference_steps = len(timesteps)
253
+ elif sigmas is not None:
254
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
255
+ if not accept_sigmas:
256
+ raise ValueError(
257
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
258
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
259
+ )
260
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
261
+ timesteps = scheduler.timesteps
262
+ num_inference_steps = len(timesteps)
263
+ else:
264
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
265
+ timesteps = scheduler.timesteps
266
+ return timesteps, num_inference_steps
267
+
268
+
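Editorial aside, not part of the committed file: a minimal usage sketch of the `retrieve_timesteps` helper above. It assumes the function is in scope (e.g. the snippet runs inside this module) and that the installed scheduler's `set_timesteps` exposes the chosen keyword.

```py
# Hedged sketch of retrieve_timesteps usage (illustrative only).
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)

# Standard path: let the scheduler build its own spacing for 30 steps.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_steps, timesteps[:3])

# Passing both `timesteps` and `sigmas` raises a ValueError; custom values are only accepted
# when the scheduler's set_timesteps signature includes the corresponding keyword.
```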
269
+ class StableDiffusionXLInpaintPipeline(
270
+ DiffusionPipeline,
271
+ StableDiffusionMixin,
272
+ TextualInversionLoaderMixin,
273
+ StableDiffusionXLLoraLoaderMixin,
274
+ FromSingleFileMixin,
275
+ IPAdapterMixin,
276
+ ):
277
+ r"""
278
+ Pipeline for text-guided image inpainting using Stable Diffusion XL, with optional T2I-Adapter image conditioning.
279
+
280
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
281
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
282
+
283
+ The pipeline also inherits the following loading methods:
284
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
285
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
286
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
287
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
288
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
289
+
290
+ Args:
291
+ vae ([`AutoencoderKL`]):
292
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
293
+ text_encoder ([`CLIPTextModel`]):
294
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
295
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
296
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
297
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
298
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
299
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
300
+ specifically the
301
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
302
+ variant.
303
+ tokenizer (`CLIPTokenizer`):
304
+ Tokenizer of class
305
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
306
+ tokenizer_2 (`CLIPTokenizer`):
307
+ Second Tokenizer of class
308
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
309
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
310
+ scheduler ([`SchedulerMixin`]):
311
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
312
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
313
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
314
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
315
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
316
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
317
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
318
+ `stabilityai/stable-diffusion-xl-base-1-0`.
319
+ add_watermarker (`bool`, *optional*):
320
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
321
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
322
+ watermarker will be used.
323
+ """
324
+
325
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
326
+
327
+ _optional_components = [
328
+ "tokenizer",
329
+ "tokenizer_2",
330
+ "text_encoder",
331
+ "text_encoder_2",
332
+ "image_encoder",
333
+ "feature_extractor",
334
+ ]
335
+ _callback_tensor_inputs = [
336
+ "latents",
337
+ "prompt_embeds",
338
+ "negative_prompt_embeds",
339
+ "add_text_embeds",
340
+ "add_time_ids",
341
+ "negative_pooled_prompt_embeds",
342
+ "add_neg_time_ids",
343
+ "mask",
344
+ "masked_image_latents",
345
+ ]
346
+
347
+ def __init__(
348
+ self,
349
+ vae: AutoencoderKL,
350
+ text_encoder: CLIPTextModel,
351
+ text_encoder_2: CLIPTextModelWithProjection,
352
+ tokenizer: CLIPTokenizer,
353
+ tokenizer_2: CLIPTokenizer,
354
+ unet: UNet2DConditionModel,
355
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
356
+ scheduler: KarrasDiffusionSchedulers,
357
+ image_encoder: CLIPVisionModelWithProjection = None,
358
+ feature_extractor: CLIPImageProcessor = None,
359
+ requires_aesthetics_score: bool = False,
360
+ force_zeros_for_empty_prompt: bool = True,
361
+ add_watermarker: Optional[bool] = None,
362
+ ):
363
+ super().__init__()
364
+
365
+ self.register_modules(
366
+ vae=vae,
367
+ text_encoder=text_encoder,
368
+ text_encoder_2=text_encoder_2,
369
+ tokenizer=tokenizer,
370
+ tokenizer_2=tokenizer_2,
371
+ unet=unet,
372
+ adapter=adapter,
373
+ image_encoder=image_encoder,
374
+ feature_extractor=feature_extractor,
375
+ scheduler=scheduler,
376
+ )
377
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
378
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
379
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
380
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
381
+ self.mask_processor = VaeImageProcessor(
382
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
383
+ )
384
+
385
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
386
+
387
+ if add_watermarker:
388
+ self.watermark = StableDiffusionXLWatermarker()
389
+ else:
390
+ self.watermark = None
391
+
392
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
393
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
394
+ dtype = next(self.image_encoder.parameters()).dtype
395
+
396
+ if not isinstance(image, torch.Tensor):
397
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
398
+
399
+ image = image.to(device=device, dtype=dtype)
400
+ if output_hidden_states:
401
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
402
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
403
+ uncond_image_enc_hidden_states = self.image_encoder(
404
+ torch.zeros_like(image), output_hidden_states=True
405
+ ).hidden_states[-2]
406
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
407
+ num_images_per_prompt, dim=0
408
+ )
409
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
410
+ else:
411
+ image_embeds = self.image_encoder(image).image_embeds
412
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
+ uncond_image_embeds = torch.zeros_like(image_embeds)
414
+
415
+ return image_embeds, uncond_image_embeds
416
+
417
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
418
+ def prepare_ip_adapter_image_embeds(
419
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
420
+ ):
421
+ image_embeds = []
422
+ if do_classifier_free_guidance:
423
+ negative_image_embeds = []
424
+ if ip_adapter_image_embeds is None:
425
+ if not isinstance(ip_adapter_image, list):
426
+ ip_adapter_image = [ip_adapter_image]
427
+
428
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
429
+ raise ValueError(
430
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
431
+ )
432
+
433
+ for single_ip_adapter_image, image_proj_layer in zip(
434
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
435
+ ):
436
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
437
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
438
+ single_ip_adapter_image, device, 1, output_hidden_state
439
+ )
440
+
441
+ image_embeds.append(single_image_embeds[None, :])
442
+ if do_classifier_free_guidance:
443
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
444
+ else:
445
+ for single_image_embeds in ip_adapter_image_embeds:
446
+ if do_classifier_free_guidance:
447
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
448
+ negative_image_embeds.append(single_negative_image_embeds)
449
+ image_embeds.append(single_image_embeds)
450
+
451
+ ip_adapter_image_embeds = []
452
+ for i, single_image_embeds in enumerate(image_embeds):
453
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
454
+ if do_classifier_free_guidance:
455
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
456
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
457
+
458
+ single_image_embeds = single_image_embeds.to(device=device)
459
+ ip_adapter_image_embeds.append(single_image_embeds)
460
+
461
+ return ip_adapter_image_embeds
462
+
463
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
464
+ def encode_prompt(
465
+ self,
466
+ prompt: str,
467
+ prompt_2: Optional[str] = None,
468
+ device: Optional[torch.device] = None,
469
+ num_images_per_prompt: int = 1,
470
+ do_classifier_free_guidance: bool = True,
471
+ negative_prompt: Optional[str] = None,
472
+ negative_prompt_2: Optional[str] = None,
473
+ prompt_embeds: Optional[torch.Tensor] = None,
474
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
475
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
476
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
477
+ lora_scale: Optional[float] = None,
478
+ clip_skip: Optional[int] = None,
479
+ ):
480
+ r"""
481
+ Encodes the prompt into text encoder hidden states.
482
+
483
+ Args:
484
+ prompt (`str` or `List[str]`, *optional*):
485
+ prompt to be encoded
486
+ prompt_2 (`str` or `List[str]`, *optional*):
487
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
488
+ used in both text-encoders
489
+ device: (`torch.device`):
490
+ torch device
491
+ num_images_per_prompt (`int`):
492
+ number of images that should be generated per prompt
493
+ do_classifier_free_guidance (`bool`):
494
+ whether to use classifier free guidance or not
495
+ negative_prompt (`str` or `List[str]`, *optional*):
496
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
497
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
498
+ less than `1`).
499
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
500
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
501
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
502
+ prompt_embeds (`torch.Tensor`, *optional*):
503
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
504
+ provided, text embeddings will be generated from `prompt` input argument.
505
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
506
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
507
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
508
+ argument.
509
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
510
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
511
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
512
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
513
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
514
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
515
+ input argument.
516
+ lora_scale (`float`, *optional*):
517
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
518
+ clip_skip (`int`, *optional*):
519
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
520
+ the output of the pre-final layer will be used for computing the prompt embeddings.
521
+ """
522
+ device = device or self._execution_device
523
+
524
+ # set lora scale so that monkey patched LoRA
525
+ # function of text encoder can correctly access it
526
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
527
+ self._lora_scale = lora_scale
528
+
529
+ # dynamically adjust the LoRA scale
530
+ if self.text_encoder is not None:
531
+ if not USE_PEFT_BACKEND:
532
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
533
+ else:
534
+ scale_lora_layers(self.text_encoder, lora_scale)
535
+
536
+ if self.text_encoder_2 is not None:
537
+ if not USE_PEFT_BACKEND:
538
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
539
+ else:
540
+ scale_lora_layers(self.text_encoder_2, lora_scale)
541
+
542
+ prompt = [prompt] if isinstance(prompt, str) else prompt
543
+
544
+ if prompt is not None:
545
+ batch_size = len(prompt)
546
+ else:
547
+ batch_size = prompt_embeds.shape[0]
548
+
549
+ # Define tokenizers and text encoders
550
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
551
+ text_encoders = (
552
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
553
+ )
554
+
555
+ if prompt_embeds is None:
556
+ prompt_2 = prompt_2 or prompt
557
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
558
+
559
+ # textual inversion: process multi-vector tokens if necessary
560
+ prompt_embeds_list = []
561
+ prompts = [prompt, prompt_2]
562
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
563
+ if isinstance(self, TextualInversionLoaderMixin):
564
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
565
+
566
+ text_inputs = tokenizer(
567
+ prompt,
568
+ padding="max_length",
569
+ max_length=tokenizer.model_max_length,
570
+ truncation=True,
571
+ return_tensors="pt",
572
+ )
573
+
574
+ text_input_ids = text_inputs.input_ids
575
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
576
+
577
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
578
+ text_input_ids, untruncated_ids
579
+ ):
580
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
581
+ logger.warning(
582
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
583
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
584
+ )
585
+
586
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
587
+
588
+ # We are only ALWAYS interested in the pooled output of the final text encoder
589
+ pooled_prompt_embeds = prompt_embeds[0]
590
+ if clip_skip is None:
591
+ prompt_embeds = prompt_embeds.hidden_states[-2]
592
+ else:
593
+ # "2" because SDXL always indexes from the penultimate layer.
594
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
595
+
596
+ prompt_embeds_list.append(prompt_embeds)
597
+
598
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
599
+
600
+ # get unconditional embeddings for classifier free guidance
601
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
602
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
603
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
604
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
605
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
606
+ negative_prompt = negative_prompt or ""
607
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
608
+
609
+ # normalize str to list
610
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
611
+ negative_prompt_2 = (
612
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
613
+ )
614
+
615
+ uncond_tokens: List[str]
616
+ if prompt is not None and type(prompt) is not type(negative_prompt):
617
+ raise TypeError(
618
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
619
+ f" {type(prompt)}."
620
+ )
621
+ elif batch_size != len(negative_prompt):
622
+ raise ValueError(
623
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
624
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
625
+ " the batch size of `prompt`."
626
+ )
627
+ else:
628
+ uncond_tokens = [negative_prompt, negative_prompt_2]
629
+
630
+ negative_prompt_embeds_list = []
631
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
632
+ if isinstance(self, TextualInversionLoaderMixin):
633
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
634
+
635
+ max_length = prompt_embeds.shape[1]
636
+ uncond_input = tokenizer(
637
+ negative_prompt,
638
+ padding="max_length",
639
+ max_length=max_length,
640
+ truncation=True,
641
+ return_tensors="pt",
642
+ )
643
+
644
+ negative_prompt_embeds = text_encoder(
645
+ uncond_input.input_ids.to(device),
646
+ output_hidden_states=True,
647
+ )
648
+ # We are only ALWAYS interested in the pooled output of the final text encoder
649
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
650
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
651
+
652
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
653
+
654
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
655
+
656
+ if self.text_encoder_2 is not None:
657
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
658
+ else:
659
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
660
+
661
+ bs_embed, seq_len, _ = prompt_embeds.shape
662
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
663
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
664
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
665
+
666
+ if do_classifier_free_guidance:
667
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
668
+ seq_len = negative_prompt_embeds.shape[1]
669
+
670
+ if self.text_encoder_2 is not None:
671
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
672
+ else:
673
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
674
+
675
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
676
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
677
+
678
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
679
+ bs_embed * num_images_per_prompt, -1
680
+ )
681
+ if do_classifier_free_guidance:
682
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
683
+ bs_embed * num_images_per_prompt, -1
684
+ )
685
+
686
+ if self.text_encoder is not None:
687
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
688
+ # Retrieve the original scale by scaling back the LoRA layers
689
+ unscale_lora_layers(self.text_encoder, lora_scale)
690
+
691
+ if self.text_encoder_2 is not None:
692
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
693
+ # Retrieve the original scale by scaling back the LoRA layers
694
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
695
+
696
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
697
+
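Another editorial aside rather than part of the committed file: the sketch below shows the typical shape of an `encode_prompt` call on an already constructed pipeline instance. The variable `pipe`, the prompts, and the device are placeholders.

```py
# Hedged sketch: encoding prompts on a constructed pipeline instance `pipe`.
# The pooled embeddings are used as the added text embeddings in SDXL-style pipelines.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a woman, blue hair, high detailed",
    negative_prompt="low quality, blurry",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,  # i.e. guidance_scale > 1
)
```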
698
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
699
+ def prepare_extra_step_kwargs(self, generator, eta):
700
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
701
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
702
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
703
+ # and should be between [0, 1]
704
+
705
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
706
+ extra_step_kwargs = {}
707
+ if accepts_eta:
708
+ extra_step_kwargs["eta"] = eta
709
+
710
+ # check if the scheduler accepts generator
711
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
712
+ if accepts_generator:
713
+ extra_step_kwargs["generator"] = generator
714
+ return extra_step_kwargs
715
+
716
+ def check_inputs(
717
+ self,
718
+ prompt,
719
+ prompt_2,
720
+ image,
721
+ mask_image,
722
+ height,
723
+ width,
724
+ strength,
725
+ callback_steps,
726
+ output_type,
727
+ negative_prompt=None,
728
+ negative_prompt_2=None,
729
+ prompt_embeds=None,
730
+ negative_prompt_embeds=None,
731
+ ip_adapter_image=None,
732
+ ip_adapter_image_embeds=None,
733
+ callback_on_step_end_tensor_inputs=None,
734
+ padding_mask_crop=None,
735
+ ):
736
+ if strength < 0 or strength > 1:
737
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
738
+
739
+ if height % 8 != 0 or width % 8 != 0:
740
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
741
+
742
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
743
+ raise ValueError(
744
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
745
+ f" {type(callback_steps)}."
746
+ )
747
+
748
+ if callback_on_step_end_tensor_inputs is not None and not all(
749
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
750
+ ):
751
+ raise ValueError(
752
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
753
+ )
754
+
755
+ if prompt is not None and prompt_embeds is not None:
756
+ raise ValueError(
757
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
758
+ " only forward one of the two."
759
+ )
760
+ elif prompt_2 is not None and prompt_embeds is not None:
761
+ raise ValueError(
762
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
763
+ " only forward one of the two."
764
+ )
765
+ elif prompt is None and prompt_embeds is None:
766
+ raise ValueError(
767
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
768
+ )
769
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
770
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
771
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
772
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
773
+
774
+ if negative_prompt is not None and negative_prompt_embeds is not None:
775
+ raise ValueError(
776
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
777
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
778
+ )
779
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
780
+ raise ValueError(
781
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
782
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
783
+ )
784
+
785
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
786
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
787
+ raise ValueError(
788
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
789
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
790
+ f" {negative_prompt_embeds.shape}."
791
+ )
792
+ if padding_mask_crop is not None:
793
+ if not isinstance(image, PIL.Image.Image):
794
+ raise ValueError(
795
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
796
+ )
797
+ if not isinstance(mask_image, PIL.Image.Image):
798
+ raise ValueError(
799
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
800
+ f" {type(mask_image)}."
801
+ )
802
+ if output_type != "pil":
803
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
804
+
805
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
806
+ raise ValueError(
807
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
808
+ )
809
+
810
+ if ip_adapter_image_embeds is not None:
811
+ if not isinstance(ip_adapter_image_embeds, list):
812
+ raise ValueError(
813
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
814
+ )
815
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
816
+ raise ValueError(
817
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
818
+ )
819
+
820
+ def prepare_latents(
821
+ self,
822
+ batch_size,
823
+ num_channels_latents,
824
+ height,
825
+ width,
826
+ dtype,
827
+ device,
828
+ generator,
829
+ latents=None,
830
+ image=None,
831
+ timestep=None,
832
+ is_strength_max=True,
833
+ add_noise=True,
834
+ return_noise=False,
835
+ return_image_latents=False,
836
+ ):
837
+ shape = (
838
+ batch_size,
839
+ num_channels_latents,
840
+ int(height) // self.vae_scale_factor,
841
+ int(width) // self.vae_scale_factor,
842
+ )
843
+ if isinstance(generator, list) and len(generator) != batch_size:
844
+ raise ValueError(
845
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
846
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
847
+ )
848
+
849
+ if (image is None or timestep is None) and not is_strength_max:
850
+ raise ValueError(
851
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
852
+ "However, either the image or the noise timestep has not been provided."
853
+ )
854
+
855
+ if image.shape[1] == 4:
856
+ image_latents = image.to(device=device, dtype=dtype)
857
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
858
+ elif return_image_latents or (latents is None and not is_strength_max):
859
+ image = image.to(device=device, dtype=dtype)
860
+ image_latents = self._encode_vae_image(image=image, generator=generator)
861
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
862
+
863
+ if latents is None and add_noise:
864
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
865
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
866
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
867
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
868
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
869
+ elif add_noise:
870
+ noise = latents.to(device)
871
+ latents = noise * self.scheduler.init_noise_sigma
872
+ else:
873
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
874
+ latents = image_latents.to(device)
875
+
876
+ outputs = (latents,)
877
+
878
+ if return_noise:
879
+ outputs += (noise,)
880
+
881
+ if return_image_latents:
882
+ outputs += (image_latents,)
883
+
884
+ return outputs
885
+
886
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
887
+ dtype = image.dtype
888
+ if self.vae.config.force_upcast:
889
+ image = image.float()
890
+ self.vae.to(dtype=torch.float32)
891
+
892
+ if isinstance(generator, list):
893
+ image_latents = [
894
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
895
+ for i in range(image.shape[0])
896
+ ]
897
+ image_latents = torch.cat(image_latents, dim=0)
898
+ else:
899
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
900
+
901
+ if self.vae.config.force_upcast:
902
+ self.vae.to(dtype)
903
+
904
+ image_latents = image_latents.to(dtype)
905
+ image_latents = self.vae.config.scaling_factor * image_latents
906
+
907
+ return image_latents
908
+
909
+ def prepare_mask_latents(
910
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
911
+ ):
912
+ # resize the mask to latents shape as we concatenate the mask to the latents
913
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
914
+ # and half precision
915
+ mask = torch.nn.functional.interpolate(
916
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
917
+ )
918
+ mask = mask.to(device=device, dtype=dtype)
919
+
920
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
921
+ if mask.shape[0] < batch_size:
922
+ if not batch_size % mask.shape[0] == 0:
923
+ raise ValueError(
924
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
925
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
926
+ " of masks that you pass is divisible by the total requested batch size."
927
+ )
928
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
929
+
930
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
931
+
932
+ if masked_image is not None and masked_image.shape[1] == 4:
933
+ masked_image_latents = masked_image
934
+ else:
935
+ masked_image_latents = None
936
+
937
+ if masked_image is not None:
938
+ if masked_image_latents is None:
939
+ masked_image = masked_image.to(device=device, dtype=dtype)
940
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
941
+
942
+ if masked_image_latents.shape[0] < batch_size:
943
+ if not batch_size % masked_image_latents.shape[0] == 0:
944
+ raise ValueError(
945
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
946
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
947
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
948
+ )
949
+ masked_image_latents = masked_image_latents.repeat(
950
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
951
+ )
952
+
953
+ masked_image_latents = (
954
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
955
+ )
956
+
957
+ # aligning device to prevent device errors when concating it with the latent model input
958
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
959
+
960
+ return mask, masked_image_latents
961
+
962
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
963
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
964
+ # get the original timestep using init_timestep
965
+ if denoising_start is None:
966
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
967
+ t_start = max(num_inference_steps - init_timestep, 0)
968
+
969
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
970
+ if hasattr(self.scheduler, "set_begin_index"):
971
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
972
+
973
+ return timesteps, num_inference_steps - t_start
974
+
975
+ else:
976
+ # Strength is irrelevant if we directly request a timestep to start at;
977
+ # that is, strength is determined by the denoising_start instead.
978
+ discrete_timestep_cutoff = int(
979
+ round(
980
+ self.scheduler.config.num_train_timesteps
981
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
982
+ )
983
+ )
984
+
985
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
986
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
987
+ # if the scheduler is a 2nd order scheduler we might have to do +1
988
+ # because `num_inference_steps` might be even given that every timestep
989
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
990
+ # mean that we cut the timesteps in the middle of the denoising step
991
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
992
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
993
+ num_inference_steps = num_inference_steps + 1
994
+
995
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
996
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
997
+ timesteps = self.scheduler.timesteps[t_start:]
998
+ if hasattr(self.scheduler, "set_begin_index"):
999
+ self.scheduler.set_begin_index(t_start)
1000
+ return timesteps, num_inference_steps
1001
+
1002
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
1003
+ def _get_add_time_ids(
1004
+ self,
1005
+ original_size,
1006
+ crops_coords_top_left,
1007
+ target_size,
1008
+ aesthetic_score,
1009
+ negative_aesthetic_score,
1010
+ negative_original_size,
1011
+ negative_crops_coords_top_left,
1012
+ negative_target_size,
1013
+ dtype,
1014
+ text_encoder_projection_dim=None,
1015
+ ):
1016
+ if self.config.requires_aesthetics_score:
1017
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
1018
+ add_neg_time_ids = list(
1019
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
1020
+ )
1021
+ else:
1022
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1023
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
1024
+
1025
+ passed_add_embed_dim = (
1026
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
1027
+ )
1028
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1029
+
1030
+ if (
1031
+ expected_add_embed_dim > passed_add_embed_dim
1032
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1033
+ ):
1034
+ raise ValueError(
1035
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
1036
+ )
1037
+ elif (
1038
+ expected_add_embed_dim < passed_add_embed_dim
1039
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1040
+ ):
1041
+ raise ValueError(
1042
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1043
+ )
1044
+ elif expected_add_embed_dim != passed_add_embed_dim:
1045
+ raise ValueError(
1046
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1047
+ )
1048
+
1049
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1050
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1051
+
1052
+ return add_time_ids, add_neg_time_ids
1053
+
1054
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1055
+ def upcast_vae(self):
1056
+ dtype = self.vae.dtype
1057
+ self.vae.to(dtype=torch.float32)
1058
+ use_torch_2_0_or_xformers = isinstance(
1059
+ self.vae.decoder.mid_block.attentions[0].processor,
1060
+ (
1061
+ AttnProcessor2_0,
1062
+ XFormersAttnProcessor,
1063
+ ),
1064
+ )
1065
+ # if xformers or torch_2_0 is used attention block does not need
1066
+ # to be in float32 which can save lots of memory
1067
+ if use_torch_2_0_or_xformers:
1068
+ self.vae.post_quant_conv.to(dtype)
1069
+ self.vae.decoder.conv_in.to(dtype)
1070
+ self.vae.decoder.mid_block.to(dtype)
1071
+
1072
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1073
+ def get_guidance_scale_embedding(
1074
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1075
+ ) -> torch.Tensor:
1076
+ """
1077
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1078
+
1079
+ Args:
1080
+ w (`torch.Tensor`):
1081
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1082
+ embedding_dim (`int`, *optional*, defaults to 512):
1083
+ Dimension of the embeddings to generate.
1084
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1085
+ Data type of the generated embeddings.
1086
+
1087
+ Returns:
1088
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1089
+ """
1090
+ assert len(w.shape) == 1
1091
+ w = w * 1000.0
1092
+
1093
+ half_dim = embedding_dim // 2
1094
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1095
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1096
+ emb = w.to(dtype)[:, None] * emb[None, :]
1097
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1098
+ if embedding_dim % 2 == 1: # zero pad
1099
+ emb = torch.nn.functional.pad(emb, (0, 1))
1100
+ assert emb.shape == (w.shape[0], embedding_dim)
1101
+ return emb
1102
+
1103
+ @property
1104
+ def guidance_scale(self):
1105
+ return self._guidance_scale
1106
+
1107
+ @property
1108
+ def guidance_rescale(self):
1109
+ return self._guidance_rescale
1110
+
1111
+ @property
1112
+ def clip_skip(self):
1113
+ return self._clip_skip
1114
+
1115
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1116
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1117
+ # corresponds to doing no classifier free guidance.
1118
+ @property
1119
+ def do_classifier_free_guidance(self):
1120
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1121
+
1122
+ @property
1123
+ def cross_attention_kwargs(self):
1124
+ return self._cross_attention_kwargs
1125
+
1126
+ @property
1127
+ def denoising_end(self):
1128
+ return self._denoising_end
1129
+
1130
+ @property
1131
+ def denoising_start(self):
1132
+ return self._denoising_start
1133
+
1134
+ @property
1135
+ def num_timesteps(self):
1136
+ return self._num_timesteps
1137
+
1138
+ @property
1139
+ def interrupt(self):
1140
+ return self._interrupt
1141
+
1142
+ @torch.no_grad()
1143
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1144
+ def __call__(
1145
+ self,
1146
+ prompt: Union[str, List[str]] = None,
1147
+ prompt_2: Optional[Union[str, List[str]]] = None,
1148
+ image: PipelineImageInput = None,
1149
+ mask_image: PipelineImageInput = None,
1150
+ masked_image_latents: torch.Tensor = None,
1151
+ height: Optional[int] = None,
1152
+ width: Optional[int] = None,
1153
+ adapter_image: PipelineImageInput = None,
1154
+ padding_mask_crop: Optional[int] = None,
1155
+ strength: float = 0.9999,
1156
+ num_inference_steps: int = 50,
1157
+ timesteps: List[int] = None,
1158
+ sigmas: List[float] = None,
1159
+ denoising_start: Optional[float] = None,
1160
+ denoising_end: Optional[float] = None,
1161
+ guidance_scale: float = 7.5,
1162
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1163
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1164
+ num_images_per_prompt: Optional[int] = 1,
1165
+ eta: float = 0.0,
1166
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1167
+ latents: Optional[torch.Tensor] = None,
1168
+ prompt_embeds: Optional[torch.Tensor] = None,
1169
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1170
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1171
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1172
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1173
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1174
+ output_type: Optional[str] = "pil",
1175
+ return_dict: bool = True,
1176
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1177
+ guidance_rescale: float = 0.0,
1178
+ original_size: Tuple[int, int] = None,
1179
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1180
+ target_size: Tuple[int, int] = None,
1181
+ negative_original_size: Optional[Tuple[int, int]] = None,
1182
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1183
+ negative_target_size: Optional[Tuple[int, int]] = None,
1184
+ aesthetic_score: float = 6.0,
1185
+ negative_aesthetic_score: float = 2.5,
1186
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
1187
+ adapter_conditioning_factor: float = 1.0,
1188
+ clip_skip: Optional[int] = None,
1189
+ callback_on_step_end: Optional[
1190
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1191
+ ] = None,
1192
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1193
+ **kwargs,
1194
+ ):
1195
+ r"""
1196
+ Function invoked when calling the pipeline for generation.
1197
+
1198
+ Args:
1199
+ prompt (`str` or `List[str]`, *optional*):
1200
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1201
+ instead.
1202
+ prompt_2 (`str` or `List[str]`, *optional*):
1203
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1204
+ used in both text-encoders
1205
+ image (`PIL.Image.Image`):
1206
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1207
+ be masked out with `mask_image` and repainted according to `prompt`.
1208
+ mask_image (`PIL.Image.Image`):
1209
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1210
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1211
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1212
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
1213
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1214
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1215
+ Anything below 512 pixels won't work well for
1216
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1217
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1218
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1219
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1220
+ Anything below 512 pixels won't work well for
1221
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1222
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1223
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1224
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1225
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1226
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that area based
1227
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1228
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1229
+ the image is large and contains information irrelevant for inpainting, such as the background.
1230
+ strength (`float`, *optional*, defaults to 0.9999):
1231
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1232
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1233
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1234
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1235
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1236
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1237
+ integer, the value of `strength` will be ignored.
1238
+ num_inference_steps (`int`, *optional*, defaults to 50):
1239
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1240
+ expense of slower inference.
1241
+ timesteps (`List[int]`, *optional*):
1242
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1243
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1244
+ passed will be used. Must be in descending order.
1245
+ sigmas (`List[float]`, *optional*):
1246
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1247
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1248
+ will be used.
1249
+ denoising_start (`float`, *optional*):
1250
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1251
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1252
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1253
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1254
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1255
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1256
+ denoising_end (`float`, *optional*):
1257
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1258
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1259
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1260
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1261
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1262
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1263
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1264
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1265
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1266
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1267
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1268
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1269
+ usually at the expense of lower image quality.
1270
+ negative_prompt (`str` or `List[str]`, *optional*):
1271
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1272
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1273
+ less than `1`).
1274
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1275
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1276
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1277
+ prompt_embeds (`torch.Tensor`, *optional*):
1278
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1279
+ provided, text embeddings will be generated from `prompt` input argument.
1280
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1281
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1282
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1283
+ argument.
1284
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1285
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1286
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1287
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1288
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1289
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1290
+ input argument.
1291
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1292
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1293
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1294
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1295
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1296
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1297
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1298
+ The number of images to generate per prompt.
1299
+ eta (`float`, *optional*, defaults to 0.0):
1300
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1301
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1302
+ generator (`torch.Generator`, *optional*):
1303
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1304
+ to make generation deterministic.
1305
+ latents (`torch.Tensor`, *optional*):
1306
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1307
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1308
+ tensor will be generated by sampling using the supplied random `generator`.
1309
+ output_type (`str`, *optional*, defaults to `"pil"`):
1310
+ The output format of the generated image. Choose between
1311
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1312
+ return_dict (`bool`, *optional*, defaults to `True`):
1313
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1314
+ plain tuple.
1315
+ cross_attention_kwargs (`dict`, *optional*):
1316
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1317
+ `self.processor` in
1318
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1319
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1320
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1321
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1322
+ explained in section 2.2 of
1323
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1324
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1325
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1326
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1327
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1328
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1329
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1330
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1331
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1332
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1333
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1334
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1335
+ micro-conditioning as explained in section 2.2 of
1336
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1337
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1338
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1339
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1340
+ micro-conditioning as explained in section 2.2 of
1341
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1342
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1343
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1344
+ To negatively condition the generation process based on a target image resolution. It should be the same
1345
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1346
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1347
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1348
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1349
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1350
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1351
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1352
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1353
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1354
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1355
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
1356
+ clip_skip (`int`, *optional*):
1357
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1358
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1359
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1360
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1361
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
1362
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1363
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1364
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1365
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1366
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1367
+ `._callback_tensor_inputs` attribute of your pipeline class.
1368
+
1369
+ Examples:
1370
+
1371
+ Returns:
1372
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1373
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1374
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1375
+ """
1376
+ height, width = self._default_height_width(height, width, adapter_image)
1377
+ device = self._execution_device
1378
+
1379
+ if isinstance(self.adapter, MultiAdapter):
1380
+ adapter_input = []
1381
+
1382
+ for one_image in adapter_image:
1383
+ one_image = _preprocess_adapter_image(one_image, height, width)
1384
+ one_image = one_image.to(device=device, dtype=self.adapter.dtype)
1385
+ adapter_input.append(one_image)
1386
+ else:
1387
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
1388
+ adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
1389
+
1390
+ callback = kwargs.pop("callback", None)
1391
+ callback_steps = kwargs.pop("callback_steps", None)
1392
+
1393
+ if callback is not None:
1394
+ deprecate(
1395
+ "callback",
1396
+ "1.0.0",
1397
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1398
+ )
1399
+ if callback_steps is not None:
1400
+ deprecate(
1401
+ "callback_steps",
1402
+ "1.0.0",
1403
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1404
+ )
1405
+
1406
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1407
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1408
+
1409
+ # 0. Default height and width to unet
1410
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1411
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1412
+
1413
+ # 1. Check inputs
1414
+ self.check_inputs(
1415
+ prompt,
1416
+ prompt_2,
1417
+ image,
1418
+ mask_image,
1419
+ height,
1420
+ width,
1421
+ strength,
1422
+ callback_steps,
1423
+ output_type,
1424
+ negative_prompt,
1425
+ negative_prompt_2,
1426
+ prompt_embeds,
1427
+ negative_prompt_embeds,
1428
+ ip_adapter_image,
1429
+ ip_adapter_image_embeds,
1430
+ callback_on_step_end_tensor_inputs,
1431
+ padding_mask_crop,
1432
+ )
1433
+
1434
+ self._guidance_scale = guidance_scale
1435
+ self._guidance_rescale = guidance_rescale
1436
+ self._clip_skip = clip_skip
1437
+ self._cross_attention_kwargs = cross_attention_kwargs
1438
+ self._denoising_end = denoising_end
1439
+ self._denoising_start = denoising_start
1440
+ self._interrupt = False
1441
+
1442
+ # 2. Define call parameters
1443
+ if prompt is not None and isinstance(prompt, str):
1444
+ batch_size = 1
1445
+ elif prompt is not None and isinstance(prompt, list):
1446
+ batch_size = len(prompt)
1447
+ else:
1448
+ batch_size = prompt_embeds.shape[0]
1449
+
1450
+ device = self._execution_device
1451
+
1452
+ # 3. Encode input prompt
1453
+ text_encoder_lora_scale = (
1454
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1455
+ )
1456
+
1457
+ (
1458
+ prompt_embeds,
1459
+ negative_prompt_embeds,
1460
+ pooled_prompt_embeds,
1461
+ negative_pooled_prompt_embeds,
1462
+ ) = self.encode_prompt(
1463
+ prompt=prompt,
1464
+ prompt_2=prompt_2,
1465
+ device=device,
1466
+ num_images_per_prompt=num_images_per_prompt,
1467
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1468
+ negative_prompt=negative_prompt,
1469
+ negative_prompt_2=negative_prompt_2,
1470
+ prompt_embeds=prompt_embeds,
1471
+ negative_prompt_embeds=negative_prompt_embeds,
1472
+ pooled_prompt_embeds=pooled_prompt_embeds,
1473
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1474
+ lora_scale=text_encoder_lora_scale,
1475
+ clip_skip=self.clip_skip,
1476
+ )
1477
+
1478
+ # 4. set timesteps
1479
+ def denoising_value_valid(dnv):
1480
+ return isinstance(dnv, float) and 0 < dnv < 1
1481
+
1482
+ timesteps, num_inference_steps = retrieve_timesteps(
1483
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1484
+ )
1485
+ timesteps, num_inference_steps = self.get_timesteps(
1486
+ num_inference_steps,
1487
+ strength,
1488
+ device,
1489
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
1490
+ )
1491
+ # check that number of inference steps is not < 1 - as this doesn't make sense
1492
+ if num_inference_steps < 1:
1493
+ raise ValueError(
1494
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1495
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1496
+ )
1497
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1498
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1499
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1500
+ is_strength_max = strength == 1.0
1501
+
1502
+ # 5. Preprocess mask and image
1503
+ if padding_mask_crop is not None:
1504
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1505
+ resize_mode = "fill"
1506
+ else:
1507
+ crops_coords = None
1508
+ resize_mode = "default"
1509
+
1510
+ original_image = image
1511
+ init_image = self.image_processor.preprocess(
1512
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1513
+ )
1514
+ init_image = init_image.to(dtype=torch.float32)
1515
+
1516
+ mask = self.mask_processor.preprocess(
1517
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1518
+ )
1519
+
1520
+ if masked_image_latents is not None:
1521
+ masked_image = masked_image_latents
1522
+ elif init_image.shape[1] == 4:
1523
+ # if the images are already in latent space, we can't mask them
1524
+ masked_image = None
1525
+ else:
1526
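+ # zero out the region to be repainted (mask >= 0.5) so only the preserved pixels are encoded as masked_image_latents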
+ masked_image = init_image * (mask < 0.5)
1527
+
1528
+ # 6. Prepare latent variables
1529
+ num_channels_latents = self.vae.config.latent_channels
1530
+ num_channels_unet = self.unet.config.in_channels
1531
+ return_image_latents = num_channels_unet == 4
1532
+
1533
+ add_noise = True if self.denoising_start is None else False
1534
+ latents_outputs = self.prepare_latents(
1535
+ batch_size * num_images_per_prompt,
1536
+ num_channels_latents,
1537
+ height,
1538
+ width,
1539
+ prompt_embeds.dtype,
1540
+ device,
1541
+ generator,
1542
+ latents,
1543
+ image=init_image,
1544
+ timestep=latent_timestep,
1545
+ is_strength_max=is_strength_max,
1546
+ add_noise=add_noise,
1547
+ return_noise=True,
1548
+ return_image_latents=return_image_latents,
1549
+ )
1550
+
1551
+ if return_image_latents:
1552
+ latents, noise, image_latents = latents_outputs
1553
+ else:
1554
+ latents, noise = latents_outputs
1555
+
1556
+ # 7. Prepare mask latent variables
1557
+ mask, masked_image_latents = self.prepare_mask_latents(
1558
+ mask,
1559
+ masked_image,
1560
+ batch_size * num_images_per_prompt,
1561
+ height,
1562
+ width,
1563
+ prompt_embeds.dtype,
1564
+ device,
1565
+ generator,
1566
+ self.do_classifier_free_guidance,
1567
+ )
1568
+
1569
+ # 8. Check that sizes of mask, masked image and latents match
1570
+ if num_channels_unet == 9:
1571
+ # default case for runwayml/stable-diffusion-inpainting
1572
+ num_channels_mask = mask.shape[1]
1573
+ num_channels_masked_image = masked_image_latents.shape[1]
1574
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1575
+ raise ValueError(
1576
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1577
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1578
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1579
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1580
+ " `pipeline.unet` or your `mask_image` or `image` input."
1581
+ )
1582
+ elif num_channels_unet != 4:
1583
+ raise ValueError(
1584
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1585
+ )
1586
+ # 8.1 Prepare extra step kwargs.
1587
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1588
+
1589
+ # 9. Recompute height and width from the latent shape. TODO: Logic should ideally just be moved out of the pipeline
1590
+ height, width = latents.shape[-2:]
1591
+ height = height * self.vae_scale_factor
1592
+ width = width * self.vae_scale_factor
1593
+
1594
+ original_size = original_size or (height, width)
1595
+ target_size = target_size or (height, width)
1596
+
1597
+ # 10. Prepare added time ids & embeddings
1598
+ if isinstance(self.adapter, MultiAdapter):
1599
+ adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
1600
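+ # MultiAdapter applies the per-adapter conditioning scales internally, so the returned feature maps are used as-is here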
+ for k, v in enumerate(adapter_state):
1601
+ adapter_state[k] = v
1602
+ else:
1603
+ adapter_state = self.adapter(adapter_input)
1604
+ for k, v in enumerate(adapter_state):
1605
+ adapter_state[k] = v * adapter_conditioning_scale
1606
+ if num_images_per_prompt > 1:
1607
+ for k, v in enumerate(adapter_state):
1608
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
1609
+ if self.do_classifier_free_guidance:
1610
+ for k, v in enumerate(adapter_state):
1611
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
1612
+
1613
+ if negative_original_size is None:
1614
+ negative_original_size = original_size
1615
+ if negative_target_size is None:
1616
+ negative_target_size = target_size
1617
+
1618
+ add_text_embeds = pooled_prompt_embeds
1619
+ if self.text_encoder_2 is None:
1620
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1621
+ else:
1622
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1623
+
1624
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1625
+ original_size,
1626
+ crops_coords_top_left,
1627
+ target_size,
1628
+ aesthetic_score,
1629
+ negative_aesthetic_score,
1630
+ negative_original_size,
1631
+ negative_crops_coords_top_left,
1632
+ negative_target_size,
1633
+ dtype=prompt_embeds.dtype,
1634
+ text_encoder_projection_dim=text_encoder_projection_dim,
1635
+ )
1636
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1637
+
1638
+ if self.do_classifier_free_guidance:
1639
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1640
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1641
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1642
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1643
+
1644
+ prompt_embeds = prompt_embeds.to(device)
1645
+ add_text_embeds = add_text_embeds.to(device)
1646
+ add_time_ids = add_time_ids.to(device)
1647
+
1648
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1649
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1650
+ ip_adapter_image,
1651
+ ip_adapter_image_embeds,
1652
+ device,
1653
+ batch_size * num_images_per_prompt,
1654
+ self.do_classifier_free_guidance,
1655
+ )
1656
+
1657
+ # 11. Denoising loop
1658
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1659
+
1660
+ if (
1661
+ self.denoising_end is not None
1662
+ and self.denoising_start is not None
1663
+ and denoising_value_valid(self.denoising_end)
1664
+ and denoising_value_valid(self.denoising_start)
1665
+ and self.denoising_start >= self.denoising_end
1666
+ ):
1667
+ raise ValueError(
1668
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1669
+ + f" {self.denoising_end} when using type float."
1670
+ )
1671
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
1672
+ discrete_timestep_cutoff = int(
1673
+ round(
1674
+ self.scheduler.config.num_train_timesteps
1675
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1676
+ )
1677
+ )
1678
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1679
+ timesteps = timesteps[:num_inference_steps]
1680
+
1681
+ # 11.1 Optionally get Guidance Scale Embedding
1682
+ timestep_cond = None
1683
+ if self.unet.config.time_cond_proj_dim is not None:
1684
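+ # only UNets with a guidance embedding (e.g. LCM-distilled models) set time_cond_proj_dim; standard SDXL UNets skip this branch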
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1685
+ timestep_cond = self.get_guidance_scale_embedding(
1686
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1687
+ ).to(device=device, dtype=latents.dtype)
1688
+
1689
+ self._num_timesteps = len(timesteps)
1690
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1691
+ for i, t in enumerate(timesteps):
1692
+ if self.interrupt:
1693
+ continue
1694
+ # expand the latents if we are doing classifier free guidance
1695
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1696
+
1697
+ # concat latents, mask, masked_image_latents in the channel dimension
1698
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1699
+
1700
+ if num_channels_unet == 9:
1701
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1702
+
1703
+ # predict the noise residual
1704
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1705
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1706
+ added_cond_kwargs["image_embeds"] = image_embeds
1707
+
1708
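+ # apply the T2I-Adapter residuals only for the first adapter_conditioning_factor fraction of the steps (1.0 = every step)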
+ if i < int(num_inference_steps * adapter_conditioning_factor):
1709
+ down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
1710
+ else:
1711
+ down_intrablock_additional_residuals = None
1712
+
1713
+ noise_pred = self.unet(
1714
+ latent_model_input,
1715
+ t,
1716
+ encoder_hidden_states=prompt_embeds,
1717
+ timestep_cond=timestep_cond,
1718
+ cross_attention_kwargs=self.cross_attention_kwargs,
1719
+ added_cond_kwargs=added_cond_kwargs,
1720
+ return_dict=False,
1721
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
1722
+ )[0]
1723
+
1724
+ # perform guidance
1725
+ if self.do_classifier_free_guidance:
1726
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1727
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1728
+
1729
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1730
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1731
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1732
+
1733
+ # compute the previous noisy sample x_t -> x_t-1
1734
+ latents_dtype = latents.dtype
1735
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1736
+ if latents.dtype != latents_dtype:
1737
+ if torch.backends.mps.is_available():
1738
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1739
+ latents = latents.to(latents_dtype)
1740
+
1741
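+ # for a standard 4-channel UNet, blend the re-noised original latents back in outside the mask so only the masked area is edited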
+ if num_channels_unet == 4:
1742
+ init_latents_proper = image_latents
1743
+ if self.do_classifier_free_guidance:
1744
+ init_mask, _ = mask.chunk(2)
1745
+ else:
1746
+ init_mask = mask
1747
+
1748
+ if i < len(timesteps) - 1:
1749
+ noise_timestep = timesteps[i + 1]
1750
+ init_latents_proper = self.scheduler.add_noise(
1751
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1752
+ )
1753
+
1754
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1755
+
1756
+ if callback_on_step_end is not None:
1757
+ callback_kwargs = {}
1758
+ for k in callback_on_step_end_tensor_inputs:
1759
+ callback_kwargs[k] = locals()[k]
1760
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1761
+
1762
+ latents = callback_outputs.pop("latents", latents)
1763
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1764
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1765
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1766
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1767
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1768
+ )
1769
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1770
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1771
+ mask = callback_outputs.pop("mask", mask)
1772
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1773
+
1774
+ # call the callback, if provided
1775
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1776
+ progress_bar.update()
1777
+ if callback is not None and i % callback_steps == 0:
1778
+ step_idx = i // getattr(self.scheduler, "order", 1)
1779
+ callback(step_idx, t, latents)
1780
+
1781
+ if XLA_AVAILABLE:
1782
+ xm.mark_step()
1783
+
1784
+ if not output_type == "latent":
1785
+ # make sure the VAE is in float32 mode, as it overflows in float16
1786
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1787
+
1788
+ if needs_upcasting:
1789
+ self.upcast_vae()
1790
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1791
+ elif latents.dtype != self.vae.dtype:
1792
+ if torch.backends.mps.is_available():
1793
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1794
+ self.vae = self.vae.to(latents.dtype)
1795
+
1796
+ # unscale/denormalize the latents
1797
+ # denormalize with the mean and std if available and not None
1798
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1799
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1800
+ if has_latents_mean and has_latents_std:
1801
+ latents_mean = (
1802
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1803
+ )
1804
+ latents_std = (
1805
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1806
+ )
1807
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1808
+ else:
1809
+ latents = latents / self.vae.config.scaling_factor
1810
+
1811
+ image = self.vae.decode(latents, return_dict=False)[0]
1812
+
1813
+ # cast back to fp16 if needed
1814
+ if needs_upcasting:
1815
+ self.vae.to(dtype=torch.float16)
1816
+ else:
1817
+ return StableDiffusionXLPipelineOutput(images=latents)
1818
+
1819
+ # apply watermark if available
1820
+ if self.watermark is not None:
1821
+ image = self.watermark.apply_watermark(image)
1822
+
1823
+ image = self.image_processor.postprocess(image, output_type=output_type)
1824
+
1825
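+ # when padding_mask_crop was used, paste the repainted crop back onto the original image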
+ if padding_mask_crop is not None:
1826
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1827
+
1828
+ # Offload all models
1829
+ self.maybe_free_model_hooks()
1830
+
1831
+ if not return_dict:
1832
+ return (image,)
1833
+
1834
+ return StableDiffusionXLPipelineOutput(images=image)
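A minimal call sketch for the added pipeline, assuming `basepipeline` is the custom pipeline loaded in app_haircolor_inpainting.py and that the cropped image, its mask, and the lineart/canny control images are already available as PIL images (the variable names below are illustrative placeholders); every keyword argument used here appears in the `__call__` signature above:

import torch

result = basepipeline(
    prompt="a woman, blue hair, high detailed",
    negative_prompt="worst quality, low quality, blurry",
    image=origin_area_image,                     # cropped RGB image to repaint (placeholder)
    mask_image=mask_image,                       # white = repaint, black = keep (placeholder)
    adapter_image=[lineart_image, canny_image],  # one control image per adapter in the MultiAdapter (placeholders)
    height=1024,
    width=1024,
    strength=0.9,
    num_inference_steps=25,
    guidance_scale=7.5,
    adapter_conditioning_scale=[1.0, 0.5],
    adapter_conditioning_factor=1.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
edited_image = result.images[0]

Because `strength` is below 1.0 the original latents are partially kept as the starting point, and with a 4-channel base UNet the unmasked region is additionally re-blended at every step, so only the masked area changes.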
segment_utils.py CHANGED
@@ -48,6 +48,31 @@ def segment_image(input_image, category, generate_size, mask_expansion, mask_dil
48
 
49
  return origin_area_image, croper
50
 
51
  def get_face_mask(category_mask_np, dilation=1):
52
  face_skin_mask = category_mask_np == 3
53
  if dilation > 0:
 
48
 
49
  return origin_area_image, croper
50
 
51
+ def segment_image_withmask(input_image, category, generate_size, mask_expansion, mask_dilation):
52
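+ # crops a square region around the selected category mask and returns (origin_area_image, mask_image, croper) so the edited result can be pasted back later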
+ mask_size = int(generate_size)
53
+ mask_expansion = int(mask_expansion)
54
+
55
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
56
+ segmentation_result = segmenter.segment(image)
57
+ category_mask = segmentation_result.category_mask
58
+ category_mask_np = category_mask.numpy_view()
59
+
60
+ if category == "hair":
61
+ target_mask = get_hair_mask(category_mask_np, mask_dilation)
62
+ elif category == "clothes":
63
+ target_mask = get_clothes_mask(category_mask_np, mask_dilation)
64
+ elif category == "face":
65
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
66
+ else:
67
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
68
+
69
+ croper = Croper(input_image, target_mask, mask_size, mask_expansion)
70
+ croper.corp_mask_image()
71
+ origin_area_image = croper.resized_square_image
72
+ mask_image = croper.resized_square_mask_image
73
+
74
+ return origin_area_image, mask_image, croper
75
+
76
  def get_face_mask(category_mask_np, dilation=1):
77
  face_skin_mask = category_mask_np == 3
78
  if dilation > 0: