zhiweili committed
Commit: ff0aba3
Parent(s): a823397

add app_masked

Files changed:
- app.py (+1 / -1)
- app_masked.py (+124 / -0)
- pipelines/masked_stable_diffusion_xl_img2img.py (+682 / -0)
app.py
CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from
+from app_masked import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
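The hunk above only records the changed import; the removed line is truncated in this view, and the code inside `with gr.Tabs():` that actually mounts the imported factory sits outside the hunk. A minimal sketch of how `create_demo_haircolor` is typically wired into the existing Tabs block, assuming a `gr.Tab` wrapper with a hypothetical label:

```python
# Sketch only: the tab label and exact placement are assumptions, not part of this commit.
import gradio as gr

from app_masked import create_demo as create_demo_haircolor

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab(label="Hair Color"):
            # create_demo() builds its own gr.Blocks; constructing it inside this
            # Tab context renders the masked-edit UI into the parent app.
            create_demo_haircolor()

demo.launch()
```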
app_masked.py
ADDED
@@ -0,0 +1,124 @@
import spaces
import gradio as gr
import time
import torch

from PIL import Image
from segment_utils import (
    segment_image_withmask,
    restore_result,
)
from diffusers import (
    DiffusionPipeline,
)

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_EDIT_PROMPT = "a woman with linen-blonde-hair"
DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, poorly drawn face, bad face, fused face, ugly face, worst face, asymmetrical, unrealistic skin texture, bad proportions, out of frame, poorly drawn hands, cloned face, double face"

DEFAULT_CATEGORY = "hair"

basepipeline = DiffusionPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    use_safetensors=True,
    custom_pipeline="./pipelines/masked_stable_diffusion_xl_img2img.py",
)

basepipeline = basepipeline.to(DEVICE)

basepipeline.enable_xformers_memory_efficient_attention()

@spaces.GPU(duration=30)
def image_to_image(
    input_image: Image,
    mask_image: Image,
    edit_prompt: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    generate_size: int,
    blur: int,
    strength: float,
):
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    generated_image = basepipeline(
        generator=generator,
        prompt=edit_prompt,
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
        original_image=input_image,
        mask=mask_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        blur=blur,
        strength=strength,
    ).images[0]

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    return generated_image, time_cost_str

def get_time_cost(run_task_time, time_cost_str):
    now_time = int(time.time()*1000)
    if run_task_time == 0:
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += f'-->'
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str

def create_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        croper = gr.State()
        with gr.Row():
            with gr.Column():
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                generate_size = gr.Number(label="Generate Size", value=512)
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
                guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
            with gr.Column():
                with gr.Accordion("Advanced Options", open=False):
                    blur = gr.Slider(minimum=0, maximum=100, value=48, step=1, label="Blur")
                    strength = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Strength")
                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                    seed = gr.Number(label="Seed", value=8)
                    category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
                g_btn = gr.Button("Edit Image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
            with gr.Column():
                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
                mask_image = gr.Image(label="Mask Image", type="pil", interactive=False)

        g_btn.click(
            fn=segment_image_withmask,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, mask_image, croper],
        ).success(
            fn=image_to_image,
            inputs=[origin_area_image, mask_image, edit_prompt, seed, num_steps, guidance_scale, generate_size, blur, strength],
            outputs=[generated_image, generated_cost],
        ).success(
            fn=restore_result,
            inputs=[croper, category, generated_image],
            outputs=[restored_image],
        )

    return demo
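`create_demo` is written to be embedded as a tab by app.py, but the returned `gr.Blocks` can also be launched on its own, which is convenient when tuning the mask or pipeline settings. A minimal sketch, assuming `segment_utils` (not part of this commit) and a GPU runtime for the `@spaces.GPU` decorator are available:

```python
# Sketch only: standalone launcher for the masked-edit demo, for local testing.
from app_masked import create_demo

if __name__ == "__main__":
    demo = create_demo()
    demo.queue().launch()
```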
pipelines/masked_stable_diffusion_xl_img2img.py
ADDED
@@ -0,0 +1,682 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageFilter

from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    StableDiffusionXLImg2ImgPipeline,
    rescale_noise_cfg,
    retrieve_latents,
    retrieve_timesteps,
)
from diffusers.utils import (
    deprecate,
    is_torch_xla_available,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    debug_save = 0

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        original_image: PipelineImageInput = None,
        strength: float = 0.3,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        mask: Union[
            torch.FloatTensor,
            Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[Image.Image],
            List[np.ndarray],
        ] = None,
        blur=24,
        blur_compose=4,
        sample_mode="sample",
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`PipelineImageInput`):
                `Image` or tensor representing an image batch to be used as the starting point. This image might have mask painted on it.
            original_image (`PipelineImageInput`, *optional*):
                `Image` or tensor representing an image batch to be used for blending with the result.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            blur (`int`, *optional*):
                blur to apply to mask
            blur_compose (`int`, *optional*):
                blur to apply for composition of original a
            mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
                A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
            sample_mode (`str`, *optional*):
                control latents initialisation for the inpaint area, can be one of sample, argmax, random
        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # code adapted from parent class StableDiffusionXLImg2ImgPipeline
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 0. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 1. Define call parameters
        # mask is computed from difference between image and original_image
        if image is not None:
            neq = np.any(np.array(original_image) != np.array(image), axis=-1)
            mask = neq.astype(np.uint8) * 255
        else:
            assert mask is not None

        if not isinstance(mask, Image.Image):
            pil_mask = Image.fromarray(mask)
            if pil_mask.mode != "L":
                pil_mask = pil_mask.convert("L")
        mask_blur = self.blur_mask(pil_mask, blur)
        mask_compose = self.blur_mask(pil_mask, blur_compose)
        if original_image is None:
            original_image = image
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 2. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # 3. Preprocess image
        input_image = image if image is not None else original_image
        image = self.image_processor.preprocess(input_image)
        original_image = self.image_processor.preprocess(original_image)

        # 4. set timesteps
        def denoising_value_valid(dnv):
            return isinstance(dnv, float) and 0 < dnv < 1

        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
        )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        add_noise = True if self.denoising_start is None else False

        # 5. Prepare latent variables
        # It is sampled from the latent distribution of the VAE
        # that's what we repaint
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise,
            sample_mode=sample_mode,
        )

        # mean of the latent distribution
        # it is multiplied by self.vae.config.scaling_factor
        non_paint_latents = self.prepare_latents(
            original_image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise=False,
            sample_mode="argmax",
        )

        if self.debug_save:
            init_img_from_latents = self.latents_to_img(non_paint_latents)
            init_img_from_latents[0].save("non_paint_latents.png")
        # 6. create latent mask
        latent_mask = self._make_latent_mask(latents, mask)

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 8. Prepare added time ids & embeddings
        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 10. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 10.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and self.denoising_start is not None
            and denoising_value_valid(self.denoising_end)
            and denoising_value_valid(self.denoising_start)
            and self.denoising_start >= self.denoising_end
        ):
            raise ValueError(
                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
                + f" {self.denoising_end} when using type float."
            )
        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 10.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                shape = non_paint_latents.shape
                noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
                # noisy latent code of input image at current step
                orig_latents_t = non_paint_latents
                orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))

                # orig_latents_t (1 - latent_mask) + latents * latent_mask
                latents = torch.lerp(orig_latents_t, latents, latent_mask)

                if self.debug_save:
                    img1 = self.latents_to_img(latents)
                    t_str = str(t.int().item())
                    for i in range(3 - len(t_str)):
                        t_str = "0" + t_str
                    img1[0].save(f"step{t_str}.png")

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    self.vae = self.vae.to(latents.dtype)

            if self.debug_save:
                image_gen = self.latents_to_img(latents)
                image_gen[0].save("from_latent.png")

            if latent_mask is not None:
                # interpolate with latent mask
                latents = torch.lerp(non_paint_latents, latents, latent_mask)

            latents = self.denormalize(latents)
            image = self.vae.decode(latents, return_dict=False)[0]
            m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)
            img_compose = m * image + (1 - m) * original_image.to(image)
            image = img_compose
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    def _make_latent_mask(self, latents, mask):
        if mask is not None:
            latent_mask = []
            if not isinstance(mask, list):
                tmp_mask = [mask]
            else:
                tmp_mask = mask
            _, l_channels, l_height, l_width = latents.shape
            for m in tmp_mask:
                if not isinstance(m, Image.Image):
                    if len(m.shape) == 2:
                        m = m[..., np.newaxis]
                    if m.max() > 1:
                        m = m / 255.0
                    m = self.image_processor.numpy_to_pil(m)[0]
                if m.mode != "L":
                    m = m.convert("L")
                resized = self.image_processor.resize(m, l_height, l_width)
                if self.debug_save:
                    resized.save("latent_mask.png")
                latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
            latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
            latent_mask = latent_mask / max(latent_mask.max(), 1)
            return latent_mask

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        generator=None,
        add_noise=True,
        sample_mode: str = "sample",
    ):
        if not isinstance(image, (torch.Tensor, Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # Offload text encoder if `enable_model_cpu_offload` was enabled
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.text_encoder_2.to("cpu")
            torch.cuda.empty_cache()

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image
        elif sample_mode == "random":
            height, width = image.shape[-2:]
            num_channels_latents = self.unet.config.in_channels
            latents = self.random_latents(
                batch_size,
                num_channels_latents,
                height,
                width,
                dtype,
                device,
                generator,
            )
            return self.vae.config.scaling_factor * latents
        else:
            # make sure the VAE is in float32 mode, as it overflows in float16
            if self.vae.config.force_upcast:
                image = image.float()
                self.vae.to(dtype=torch.float32)

            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    retrieve_latents(
                        self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode
                    )
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=sample_mode)

            if self.vae.config.force_upcast:
                self.vae.to(dtype)

            init_latents = init_latents.to(dtype)
            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        if add_noise:
            shape = init_latents.shape
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # get latents
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def random_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def denormalize(self, latents):
        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        return latents

    def latents_to_img(self, latents):
        l1 = self.denormalize(latents)
        img1 = self.vae.decode(l1, return_dict=False)[0]
        img1 = self.image_processor.postprocess(img1, output_type="pil", do_denormalize=[True])
        return img1

    def blur_mask(self, pil_mask, blur):
        mask_blur = pil_mask.filter(ImageFilter.GaussianBlur(radius=blur))
        mask_blur = np.array(mask_blur)
        return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1, 2, 0))
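The core of this custom pipeline is the masked latent blend: at every denoising step the freshly denoised latents are interpolated against a re-noised copy of the original image's latents via `torch.lerp`, so only the masked region gets repainted, and the same interpolation is applied once more before the VAE decode. A toy illustration of that blend with made-up values (not taken from the pipeline):

```python
# Toy illustration of the masked latent blend used at each denoising step.
# torch.lerp(a, b, w) == a + w * (b - a): w = 0 keeps the original latents,
# w = 1 keeps the repainted ones, and fractional weights feather the seam.
import torch

latent_mask = torch.tensor([0.0, 0.5, 1.0])      # 0 = keep original, 1 = repaint
orig_latents_t = torch.tensor([1.0, 1.0, 1.0])   # noised latents of the original image
denoised = torch.tensor([5.0, 5.0, 5.0])         # current repainted latents

blended = torch.lerp(orig_latents_t, denoised, latent_mask)
print(blended)  # tensor([1., 3., 5.])
```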