zhiweili commited on
Commit
ff0aba3
1 Parent(s): a823397

add app_masked

Browse files
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from app_ddim import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
+ from app_masked import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
app_masked.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+
6
+ from PIL import Image
7
+ from segment_utils import(
8
+ segment_image_withmask,
9
+ restore_result,
10
+ )
11
+ from diffusers import (
12
+ DiffusionPipeline,
13
+ )
14
+
15
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
16
+
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ DEFAULT_EDIT_PROMPT = "a woman with linen-blonde-hair"
20
+ DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, poorly drawn face, bad face, fused face, ugly face, worst face, asymmetrical, unrealistic skin texture, bad proportions, out of frame, poorly drawn hands, cloned face, double face"
21
+
22
+ DEFAULT_CATEGORY = "hair"
23
+
24
+ basepipeline = DiffusionPipeline.from_pretrained(
25
+ BASE_MODEL,
26
+ torch_dtype=torch.float16,
27
+ use_safetensors=True,
28
+ custom_pipeline="./pipelines/masked_stable_diffusion_xl_img2img.py",
29
+ )
30
+
31
+ basepipeline = basepipeline.to(DEVICE)
32
+
33
+ basepipeline.enable_xformers_memory_efficient_attention()
34
+
35
+ @spaces.GPU(duration=30)
36
+ def image_to_image(
37
+ input_image: Image,
38
+ mask_image: Image,
39
+ edit_prompt: str,
40
+ seed: int,
41
+ num_steps: int,
42
+ guidance_scale: float,
43
+ generate_size: int,
44
+ blur: int,
45
+ strength: float,
46
+ ):
47
+ run_task_time = 0
48
+ time_cost_str = ''
49
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
50
+
51
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
52
+ generated_image = basepipeline(
53
+ generator=generator,
54
+ prompt=edit_prompt,
55
+ negative_prompt=DEFAULT_NEGATIVE_PROMPT,
56
+ original_image=input_image,
57
+ mask=mask_image,
58
+ guidance_scale=guidance_scale,
59
+ num_inference_steps=num_steps,
60
+ blur=blur,
61
+ strength=strength,
62
+ ).images[0]
63
+
64
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
65
+
66
+ return generated_image, time_cost_str
67
+
68
+ def get_time_cost(run_task_time, time_cost_str):
69
+ now_time = int(time.time()*1000)
70
+ if run_task_time == 0:
71
+ time_cost_str = 'start'
72
+ else:
73
+ if time_cost_str != '':
74
+ time_cost_str += f'-->'
75
+ time_cost_str += f'{now_time - run_task_time}'
76
+ run_task_time = now_time
77
+ return run_task_time, time_cost_str
78
+
79
+ def create_demo() -> gr.Blocks:
80
+ with gr.Blocks() as demo:
81
+ croper = gr.State()
82
+ with gr.Row():
83
+ with gr.Column():
84
+ edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
85
+ generate_size = gr.Number(label="Generate Size", value=512)
86
+ with gr.Column():
87
+ num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
88
+ guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
89
+ with gr.Column():
90
+ with gr.Accordion("Advanced Options", open=False):
91
+ blur = gr.Slider(minimum=0, maximum=100, value=48, step=1, label="Blur")
92
+ strength = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Strength")
93
+ mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
94
+ mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
95
+ seed = gr.Number(label="Seed", value=8)
96
+ category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
97
+ g_btn = gr.Button("Edit Image")
98
+
99
+ with gr.Row():
100
+ with gr.Column():
101
+ input_image = gr.Image(label="Input Image", type="pil")
102
+ with gr.Column():
103
+ restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
104
+ with gr.Column():
105
+ origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
106
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
107
+ generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
108
+ mask_image = gr.Image(label="Mask Image", type="pil", interactive=False)
109
+
110
+ g_btn.click(
111
+ fn=segment_image_withmask,
112
+ inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
113
+ outputs=[origin_area_image, mask_image, croper],
114
+ ).success(
115
+ fn=image_to_image,
116
+ inputs=[origin_area_image, mask_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, blur, strength],
117
+ outputs=[generated_image, generated_cost],
118
+ ).success(
119
+ fn=restore_result,
120
+ inputs=[croper, category, generated_image],
121
+ outputs=[restored_image],
122
+ )
123
+
124
+ return demo
pipelines/masked_stable_diffusion_xl_img2img.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image, ImageFilter
6
+
7
+ from diffusers.image_processor import PipelineImageInput
8
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
9
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
10
+ StableDiffusionXLImg2ImgPipeline,
11
+ rescale_noise_cfg,
12
+ retrieve_latents,
13
+ retrieve_timesteps,
14
+ )
15
+ from diffusers.utils import (
16
+ deprecate,
17
+ is_torch_xla_available,
18
+ logging,
19
+ )
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+
22
+
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+
26
+ XLA_AVAILABLE = True
27
+ else:
28
+ XLA_AVAILABLE = False
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ class MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
35
+ debug_save = 0
36
+
37
+ @torch.no_grad()
38
+ def __call__(
39
+ self,
40
+ prompt: Union[str, List[str]] = None,
41
+ prompt_2: Optional[Union[str, List[str]]] = None,
42
+ image: PipelineImageInput = None,
43
+ original_image: PipelineImageInput = None,
44
+ strength: float = 0.3,
45
+ num_inference_steps: Optional[int] = 50,
46
+ timesteps: List[int] = None,
47
+ denoising_start: Optional[float] = None,
48
+ denoising_end: Optional[float] = None,
49
+ guidance_scale: Optional[float] = 5.0,
50
+ negative_prompt: Optional[Union[str, List[str]]] = None,
51
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
52
+ num_images_per_prompt: Optional[int] = 1,
53
+ eta: Optional[float] = 0.0,
54
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
55
+ latents: Optional[torch.FloatTensor] = None,
56
+ prompt_embeds: Optional[torch.FloatTensor] = None,
57
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
59
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
60
+ ip_adapter_image: Optional[PipelineImageInput] = None,
61
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
62
+ output_type: Optional[str] = "pil",
63
+ return_dict: bool = True,
64
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ guidance_rescale: float = 0.0,
66
+ original_size: Tuple[int, int] = None,
67
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
68
+ target_size: Tuple[int, int] = None,
69
+ negative_original_size: Optional[Tuple[int, int]] = None,
70
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
71
+ negative_target_size: Optional[Tuple[int, int]] = None,
72
+ aesthetic_score: float = 6.0,
73
+ negative_aesthetic_score: float = 2.5,
74
+ clip_skip: Optional[int] = None,
75
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
76
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
77
+ mask: Union[
78
+ torch.FloatTensor,
79
+ Image.Image,
80
+ np.ndarray,
81
+ List[torch.FloatTensor],
82
+ List[Image.Image],
83
+ List[np.ndarray],
84
+ ] = None,
85
+ blur=24,
86
+ blur_compose=4,
87
+ sample_mode="sample",
88
+ **kwargs,
89
+ ):
90
+ r"""
91
+ The call function to the pipeline for generation.
92
+
93
+ Args:
94
+ prompt (`str` or `List[str]`, *optional*):
95
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
96
+ image (`PipelineImageInput`):
97
+ `Image` or tensor representing an image batch to be used as the starting point. This image might have mask painted on it.
98
+ original_image (`PipelineImageInput`, *optional*):
99
+ `Image` or tensor representing an image batch to be used for blending with the result.
100
+ strength (`float`, *optional*, defaults to 0.8):
101
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
102
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
103
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
104
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
105
+ essentially ignores `image`.
106
+ num_inference_steps (`int`, *optional*, defaults to 50):
107
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
108
+ expense of slower inference. This parameter is modulated by `strength`.
109
+ guidance_scale (`float`, *optional*, defaults to 7.5):
110
+ A higher guidance scale value encourages the model to generate images closely linked to the text
111
+ ,`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
112
+ negative_prompt (`str` or `List[str]`, *optional*):
113
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
114
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
115
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
116
+ The number of images to generate per prompt.
117
+ eta (`float`, *optional*, defaults to 0.0):
118
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
119
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
120
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
121
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
122
+ generation deterministic.
123
+ prompt_embeds (`torch.FloatTensor`, *optional*):
124
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
125
+ provided, text embeddings are generated from the `prompt` input argument.
126
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
127
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
128
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
129
+ output_type (`str`, *optional*, defaults to `"pil"`):
130
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
131
+ return_dict (`bool`, *optional*, defaults to `True`):
132
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
133
+ plain tuple.
134
+ callback (`Callable`, *optional*):
135
+ A function that calls every `callback_steps` steps during inference. The function is called with the
136
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
137
+ callback_steps (`int`, *optional*, defaults to 1):
138
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
139
+ every step.
140
+ cross_attention_kwargs (`dict`, *optional*):
141
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
142
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
143
+ blur (`int`, *optional*):
144
+ blur to apply to mask
145
+ blur_compose (`int`, *optional*):
146
+ blur to apply for composition of original a
147
+ mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
148
+ A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
149
+ sample_mode (`str`, *optional*):
150
+ control latents initialisation for the inpaint area, can be one of sample, argmax, random
151
+ Examples:
152
+
153
+ Returns:
154
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
155
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
156
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
157
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
158
+ "not-safe-for-work" (nsfw) content.
159
+ """
160
+ # code adapted from parent class StableDiffusionXLImg2ImgPipeline
161
+ callback = kwargs.pop("callback", None)
162
+ callback_steps = kwargs.pop("callback_steps", None)
163
+
164
+ if callback is not None:
165
+ deprecate(
166
+ "callback",
167
+ "1.0.0",
168
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
169
+ )
170
+ if callback_steps is not None:
171
+ deprecate(
172
+ "callback_steps",
173
+ "1.0.0",
174
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
175
+ )
176
+
177
+ # 0. Check inputs. Raise error if not correct
178
+ self.check_inputs(
179
+ prompt,
180
+ prompt_2,
181
+ strength,
182
+ num_inference_steps,
183
+ callback_steps,
184
+ negative_prompt,
185
+ negative_prompt_2,
186
+ prompt_embeds,
187
+ negative_prompt_embeds,
188
+ ip_adapter_image,
189
+ ip_adapter_image_embeds,
190
+ callback_on_step_end_tensor_inputs,
191
+ )
192
+
193
+ self._guidance_scale = guidance_scale
194
+ self._guidance_rescale = guidance_rescale
195
+ self._clip_skip = clip_skip
196
+ self._cross_attention_kwargs = cross_attention_kwargs
197
+ self._denoising_end = denoising_end
198
+ self._denoising_start = denoising_start
199
+ self._interrupt = False
200
+
201
+ # 1. Define call parameters
202
+ # mask is computed from difference between image and original_image
203
+ if image is not None:
204
+ neq = np.any(np.array(original_image) != np.array(image), axis=-1)
205
+ mask = neq.astype(np.uint8) * 255
206
+ else:
207
+ assert mask is not None
208
+
209
+ if not isinstance(mask, Image.Image):
210
+ pil_mask = Image.fromarray(mask)
211
+ if pil_mask.mode != "L":
212
+ pil_mask = pil_mask.convert("L")
213
+ mask_blur = self.blur_mask(pil_mask, blur)
214
+ mask_compose = self.blur_mask(pil_mask, blur_compose)
215
+ if original_image is None:
216
+ original_image = image
217
+ if prompt is not None and isinstance(prompt, str):
218
+ batch_size = 1
219
+ elif prompt is not None and isinstance(prompt, list):
220
+ batch_size = len(prompt)
221
+ else:
222
+ batch_size = prompt_embeds.shape[0]
223
+
224
+ device = self._execution_device
225
+
226
+ # 2. Encode input prompt
227
+ text_encoder_lora_scale = (
228
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
229
+ )
230
+ (
231
+ prompt_embeds,
232
+ negative_prompt_embeds,
233
+ pooled_prompt_embeds,
234
+ negative_pooled_prompt_embeds,
235
+ ) = self.encode_prompt(
236
+ prompt=prompt,
237
+ prompt_2=prompt_2,
238
+ device=device,
239
+ num_images_per_prompt=num_images_per_prompt,
240
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
241
+ negative_prompt=negative_prompt,
242
+ negative_prompt_2=negative_prompt_2,
243
+ prompt_embeds=prompt_embeds,
244
+ negative_prompt_embeds=negative_prompt_embeds,
245
+ pooled_prompt_embeds=pooled_prompt_embeds,
246
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
247
+ lora_scale=text_encoder_lora_scale,
248
+ clip_skip=self.clip_skip,
249
+ )
250
+
251
+ # 3. Preprocess image
252
+ input_image = image if image is not None else original_image
253
+ image = self.image_processor.preprocess(input_image)
254
+ original_image = self.image_processor.preprocess(original_image)
255
+
256
+ # 4. set timesteps
257
+ def denoising_value_valid(dnv):
258
+ return isinstance(dnv, float) and 0 < dnv < 1
259
+
260
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
261
+ timesteps, num_inference_steps = self.get_timesteps(
262
+ num_inference_steps,
263
+ strength,
264
+ device,
265
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
266
+ )
267
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
268
+
269
+ add_noise = True if self.denoising_start is None else False
270
+
271
+ # 5. Prepare latent variables
272
+ # It is sampled from the latent distribution of the VAE
273
+ # that's what we repaint
274
+ latents = self.prepare_latents(
275
+ image,
276
+ latent_timestep,
277
+ batch_size,
278
+ num_images_per_prompt,
279
+ prompt_embeds.dtype,
280
+ device,
281
+ generator,
282
+ add_noise,
283
+ sample_mode=sample_mode,
284
+ )
285
+
286
+ # mean of the latent distribution
287
+ # it is multiplied by self.vae.config.scaling_factor
288
+ non_paint_latents = self.prepare_latents(
289
+ original_image,
290
+ latent_timestep,
291
+ batch_size,
292
+ num_images_per_prompt,
293
+ prompt_embeds.dtype,
294
+ device,
295
+ generator,
296
+ add_noise=False,
297
+ sample_mode="argmax",
298
+ )
299
+
300
+ if self.debug_save:
301
+ init_img_from_latents = self.latents_to_img(non_paint_latents)
302
+ init_img_from_latents[0].save("non_paint_latents.png")
303
+ # 6. create latent mask
304
+ latent_mask = self._make_latent_mask(latents, mask)
305
+
306
+ # 7. Prepare extra step kwargs.
307
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
308
+
309
+ height, width = latents.shape[-2:]
310
+ height = height * self.vae_scale_factor
311
+ width = width * self.vae_scale_factor
312
+
313
+ original_size = original_size or (height, width)
314
+ target_size = target_size or (height, width)
315
+
316
+ # 8. Prepare added time ids & embeddings
317
+ if negative_original_size is None:
318
+ negative_original_size = original_size
319
+ if negative_target_size is None:
320
+ negative_target_size = target_size
321
+
322
+ add_text_embeds = pooled_prompt_embeds
323
+ if self.text_encoder_2 is None:
324
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
325
+ else:
326
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
327
+
328
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
329
+ original_size,
330
+ crops_coords_top_left,
331
+ target_size,
332
+ aesthetic_score,
333
+ negative_aesthetic_score,
334
+ negative_original_size,
335
+ negative_crops_coords_top_left,
336
+ negative_target_size,
337
+ dtype=prompt_embeds.dtype,
338
+ text_encoder_projection_dim=text_encoder_projection_dim,
339
+ )
340
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
341
+
342
+ if self.do_classifier_free_guidance:
343
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
344
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
345
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
346
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
347
+
348
+ prompt_embeds = prompt_embeds.to(device)
349
+ add_text_embeds = add_text_embeds.to(device)
350
+ add_time_ids = add_time_ids.to(device)
351
+
352
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
353
+ image_embeds = self.prepare_ip_adapter_image_embeds(
354
+ ip_adapter_image,
355
+ ip_adapter_image_embeds,
356
+ device,
357
+ batch_size * num_images_per_prompt,
358
+ self.do_classifier_free_guidance,
359
+ )
360
+
361
+ # 10. Denoising loop
362
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
363
+
364
+ # 10.1 Apply denoising_end
365
+ if (
366
+ self.denoising_end is not None
367
+ and self.denoising_start is not None
368
+ and denoising_value_valid(self.denoising_end)
369
+ and denoising_value_valid(self.denoising_start)
370
+ and self.denoising_start >= self.denoising_end
371
+ ):
372
+ raise ValueError(
373
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
374
+ + f" {self.denoising_end} when using type float."
375
+ )
376
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
377
+ discrete_timestep_cutoff = int(
378
+ round(
379
+ self.scheduler.config.num_train_timesteps
380
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
381
+ )
382
+ )
383
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
384
+ timesteps = timesteps[:num_inference_steps]
385
+
386
+ # 10.2 Optionally get Guidance Scale Embedding
387
+ timestep_cond = None
388
+ if self.unet.config.time_cond_proj_dim is not None:
389
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
390
+ timestep_cond = self.get_guidance_scale_embedding(
391
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
392
+ ).to(device=device, dtype=latents.dtype)
393
+
394
+ self._num_timesteps = len(timesteps)
395
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
396
+ for i, t in enumerate(timesteps):
397
+ if self.interrupt:
398
+ continue
399
+
400
+ shape = non_paint_latents.shape
401
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
402
+ # noisy latent code of input image at current step
403
+ orig_latents_t = non_paint_latents
404
+ orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))
405
+
406
+ # orig_latents_t (1 - latent_mask) + latents * latent_mask
407
+ latents = torch.lerp(orig_latents_t, latents, latent_mask)
408
+
409
+ if self.debug_save:
410
+ img1 = self.latents_to_img(latents)
411
+ t_str = str(t.int().item())
412
+ for i in range(3 - len(t_str)):
413
+ t_str = "0" + t_str
414
+ img1[0].save(f"step{t_str}.png")
415
+
416
+ # expand the latents if we are doing classifier free guidance
417
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
418
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
419
+
420
+ # predict the noise residual
421
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
422
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
423
+ added_cond_kwargs["image_embeds"] = image_embeds
424
+
425
+ noise_pred = self.unet(
426
+ latent_model_input,
427
+ t,
428
+ encoder_hidden_states=prompt_embeds,
429
+ timestep_cond=timestep_cond,
430
+ cross_attention_kwargs=self.cross_attention_kwargs,
431
+ added_cond_kwargs=added_cond_kwargs,
432
+ return_dict=False,
433
+ )[0]
434
+
435
+ # perform guidance
436
+ if self.do_classifier_free_guidance:
437
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
438
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
439
+
440
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
441
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
442
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
443
+
444
+ # compute the previous noisy sample x_t -> x_t-1
445
+ latents_dtype = latents.dtype
446
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
447
+
448
+ if latents.dtype != latents_dtype:
449
+ if torch.backends.mps.is_available():
450
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
451
+ latents = latents.to(latents_dtype)
452
+
453
+ if callback_on_step_end is not None:
454
+ callback_kwargs = {}
455
+ for k in callback_on_step_end_tensor_inputs:
456
+ callback_kwargs[k] = locals()[k]
457
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
458
+
459
+ latents = callback_outputs.pop("latents", latents)
460
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
461
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
462
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
463
+ negative_pooled_prompt_embeds = callback_outputs.pop(
464
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
465
+ )
466
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
467
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
468
+
469
+ # call the callback, if provided
470
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
471
+ progress_bar.update()
472
+ if callback is not None and i % callback_steps == 0:
473
+ step_idx = i // getattr(self.scheduler, "order", 1)
474
+ callback(step_idx, t, latents)
475
+
476
+ if XLA_AVAILABLE:
477
+ xm.mark_step()
478
+
479
+ if not output_type == "latent":
480
+ # make sure the VAE is in float32 mode, as it overflows in float16
481
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
482
+
483
+ if needs_upcasting:
484
+ self.upcast_vae()
485
+ elif latents.dtype != self.vae.dtype:
486
+ if torch.backends.mps.is_available():
487
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
488
+ self.vae = self.vae.to(latents.dtype)
489
+
490
+ if self.debug_save:
491
+ image_gen = self.latents_to_img(latents)
492
+ image_gen[0].save("from_latent.png")
493
+
494
+ if latent_mask is not None:
495
+ # interpolate with latent mask
496
+ latents = torch.lerp(non_paint_latents, latents, latent_mask)
497
+
498
+ latents = self.denormalize(latents)
499
+ image = self.vae.decode(latents, return_dict=False)[0]
500
+ m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)
501
+ img_compose = m * image + (1 - m) * original_image.to(image)
502
+ image = img_compose
503
+ # cast back to fp16 if needed
504
+ if needs_upcasting:
505
+ self.vae.to(dtype=torch.float16)
506
+ else:
507
+ image = latents
508
+
509
+ # apply watermark if available
510
+ if self.watermark is not None:
511
+ image = self.watermark.apply_watermark(image)
512
+
513
+ image = self.image_processor.postprocess(image, output_type=output_type)
514
+
515
+ # Offload all models
516
+ self.maybe_free_model_hooks()
517
+
518
+ if not return_dict:
519
+ return (image,)
520
+
521
+ return StableDiffusionXLPipelineOutput(images=image)
522
+
523
+ def _make_latent_mask(self, latents, mask):
524
+ if mask is not None:
525
+ latent_mask = []
526
+ if not isinstance(mask, list):
527
+ tmp_mask = [mask]
528
+ else:
529
+ tmp_mask = mask
530
+ _, l_channels, l_height, l_width = latents.shape
531
+ for m in tmp_mask:
532
+ if not isinstance(m, Image.Image):
533
+ if len(m.shape) == 2:
534
+ m = m[..., np.newaxis]
535
+ if m.max() > 1:
536
+ m = m / 255.0
537
+ m = self.image_processor.numpy_to_pil(m)[0]
538
+ if m.mode != "L":
539
+ m = m.convert("L")
540
+ resized = self.image_processor.resize(m, l_height, l_width)
541
+ if self.debug_save:
542
+ resized.save("latent_mask.png")
543
+ latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
544
+ latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
545
+ latent_mask = latent_mask / max(latent_mask.max(), 1)
546
+ return latent_mask
547
+
548
+ def prepare_latents(
549
+ self,
550
+ image,
551
+ timestep,
552
+ batch_size,
553
+ num_images_per_prompt,
554
+ dtype,
555
+ device,
556
+ generator=None,
557
+ add_noise=True,
558
+ sample_mode: str = "sample",
559
+ ):
560
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
561
+ raise ValueError(
562
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
563
+ )
564
+
565
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
566
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
567
+ self.text_encoder_2.to("cpu")
568
+ torch.cuda.empty_cache()
569
+
570
+ image = image.to(device=device, dtype=dtype)
571
+
572
+ batch_size = batch_size * num_images_per_prompt
573
+
574
+ if image.shape[1] == 4:
575
+ init_latents = image
576
+ elif sample_mode == "random":
577
+ height, width = image.shape[-2:]
578
+ num_channels_latents = self.unet.config.in_channels
579
+ latents = self.random_latents(
580
+ batch_size,
581
+ num_channels_latents,
582
+ height,
583
+ width,
584
+ dtype,
585
+ device,
586
+ generator,
587
+ )
588
+ return self.vae.config.scaling_factor * latents
589
+ else:
590
+ # make sure the VAE is in float32 mode, as it overflows in float16
591
+ if self.vae.config.force_upcast:
592
+ image = image.float()
593
+ self.vae.to(dtype=torch.float32)
594
+
595
+ if isinstance(generator, list) and len(generator) != batch_size:
596
+ raise ValueError(
597
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
598
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
599
+ )
600
+
601
+ elif isinstance(generator, list):
602
+ init_latents = [
603
+ retrieve_latents(
604
+ self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode
605
+ )
606
+ for i in range(batch_size)
607
+ ]
608
+ init_latents = torch.cat(init_latents, dim=0)
609
+ else:
610
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=sample_mode)
611
+
612
+ if self.vae.config.force_upcast:
613
+ self.vae.to(dtype)
614
+
615
+ init_latents = init_latents.to(dtype)
616
+ init_latents = self.vae.config.scaling_factor * init_latents
617
+
618
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
619
+ # expand init_latents for batch_size
620
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
621
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
622
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
623
+ raise ValueError(
624
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
625
+ )
626
+ else:
627
+ init_latents = torch.cat([init_latents], dim=0)
628
+
629
+ if add_noise:
630
+ shape = init_latents.shape
631
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
632
+ # get latents
633
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
634
+
635
+ latents = init_latents
636
+
637
+ return latents
638
+
639
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
640
+ def random_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
641
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
642
+ if isinstance(generator, list) and len(generator) != batch_size:
643
+ raise ValueError(
644
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
645
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
646
+ )
647
+
648
+ if latents is None:
649
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
650
+ else:
651
+ latents = latents.to(device)
652
+
653
+ # scale the initial noise by the standard deviation required by the scheduler
654
+ latents = latents * self.scheduler.init_noise_sigma
655
+ return latents
656
+
657
+ def denormalize(self, latents):
658
+ # unscale/denormalize the latents
659
+ # denormalize with the mean and std if available and not None
660
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
661
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
662
+ if has_latents_mean and has_latents_std:
663
+ latents_mean = (
664
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
665
+ )
666
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
667
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
668
+ else:
669
+ latents = latents / self.vae.config.scaling_factor
670
+
671
+ return latents
672
+
673
+ def latents_to_img(self, latents):
674
+ l1 = self.denormalize(latents)
675
+ img1 = self.vae.decode(l1, return_dict=False)[0]
676
+ img1 = self.image_processor.postprocess(img1, output_type="pil", do_denormalize=[True])
677
+ return img1
678
+
679
+ def blur_mask(self, pil_mask, blur):
680
+ mask_blur = pil_mask.filter(ImageFilter.GaussianBlur(radius=blur))
681
+ mask_blur = np.array(mask_blur)
682
+ return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1, 2, 0))