zhiweili committed on
Commit
4743900
1 Parent(s): df579d7
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .vscode
+ .DS_Store
+ __pycache__
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
  license: mit
  ---
 
+ Modified from: https://huggingface.co/spaces/turboedit/turbo_edit
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,10 @@
+ import gradio as gr
+
+ from app_base import create_demo as create_demo_face
+
+ with gr.Blocks(css="style.css") as demo:
+     with gr.Tabs():
+         with gr.Tab(label="Face"):
+             create_demo_face()
+
+ demo.launch()
app_base.py ADDED
@@ -0,0 +1,90 @@
+ import spaces
+ import gradio as gr
+ import torch
+ import uuid
+ import os
+
+ from PIL import Image
+ from enhance_utils import enhance_image
+
+ DEFAULT_SRC_PROMPT = "a woman, photo"
+ DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def create_demo() -> gr.Blocks:
+     from inversion_run_base import run as base_run
+
+     @spaces.GPU(duration=10)
+     def image_to_image(
+         input_image: Image,
+         input_image_prompt: str,
+         edit_prompt: str,
+         seed: int,
+         w1: float,
+         num_steps: int,
+         start_step: int,
+         guidance_scale: float,
+         enhance_face: bool = True,
+     ):
+         w2 = 1.0
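+         # Note: w1 weights the edit direction and w2 the source-consistency term in the
+         # inversion update (see step_use_latents in inversion_utils); only w1 is exposed
+         # in the UI and w2 stays fixed at 1.0.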
+
+         run_model = base_run
+         res_image = run_model(
+             input_image,
+             input_image_prompt,
+             edit_prompt,
+             seed,
+             w1,
+             w2,
+             num_steps,
+             start_step,
+             guidance_scale,
+         )
+         enhanced_image = enhance_image(res_image, enhance_face)
+
+         tmpPrefix = "/tmp/gradio/"
+
+         extension = 'png'
+         if enhanced_image.mode == 'RGBA':
+             extension = 'png'
+         else:
+             extension = 'jpg'
+
+         targetDir = f"{tmpPrefix}output/"
+         if not os.path.exists(targetDir):
+             os.makedirs(targetDir)
+
+         enhanced_path = f"{targetDir}{uuid.uuid4()}.{extension}"
+         enhanced_image.save(enhanced_path, quality=100)
+
+         return enhanced_path
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 input_image_path = gr.File(label="Input Image")
+             with gr.Column():
+                 generated_image_path = gr.File(label="Download the edited image", interactive=False)
+         with gr.Row():
+             with gr.Column():
+                 input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
+                 edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                 with gr.Accordion("Advanced Options", open=False):
+                     guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
+                     enhance_face = gr.Checkbox(label="Enhance Face", value=False)
+                     seed = gr.Number(label="Seed", value=8)
+             with gr.Column():
+                 num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
+                 start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
+                 w1 = gr.Number(label="W1", value=2)
+                 g_btn = gr.Button("Edit Image")
+
+         g_btn.click(
+             fn=image_to_image,
+             inputs=[input_image_path, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, enhance_face],
+             outputs=[generated_image_path],
+         )
+
+     return demo
config.py ADDED
@@ -0,0 +1,141 @@
+ from ml_collections import config_dict
+ import yaml
+ from diffusers.schedulers import (
+     DDIMScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     DDPMScheduler,
+ )
+ from inversion_utils import (
+     deterministic_ddim_step,
+     deterministic_ddpm_step,
+     deterministic_euler_step,
+     deterministic_non_ancestral_euler_step,
+ )
+
+ BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"]
+ SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"]
+ MODELS = [
+     "stabilityai/sdxl-turbo",
+     "stabilityai/stable-diffusion-xl-base-1.0",
+     "CompVis/stable-diffusion-v1-4",
+ ]
+
+ def get_num_steps_actual(cfg):
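+     # Effective number of denoising steps: the length of the explicit timestep list if one
+     # is given, otherwise num_steps_inversion minus the skipped start steps; one extra
+     # "clean" step is counted when clean_step_timestep > 0.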
+     return (
+         cfg.num_steps_inversion
+         - cfg.step_start
+         + (1 if cfg.clean_step_timestep > 0 else 0)
+         if cfg.timesteps is None
+         else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0)
+     )
+
+
+ def get_config(args):
+     if args.config_from_file and args.config_from_file != "":
+         with open(args.config_from_file, "r") as f:
+             cfg = config_dict.ConfigDict(yaml.safe_load(f))
+
+         num_steps_actual = get_num_steps_actual(cfg)
+
+     else:
+         cfg = config_dict.ConfigDict()
+
+         cfg.seed = 2
+         cfg.self_r = 0.5
+         cfg.cross_r = 0.9
+         cfg.eta = 1
+         cfg.scheduler_type = SCHEDULERS[0]
+
+         cfg.num_steps_inversion = 50  # timesteps: 999, 799, 599, 399, 199
+         cfg.step_start = 20
+         cfg.timesteps = None
+         cfg.noise_timesteps = None
+         num_steps_actual = get_num_steps_actual(cfg)
+         cfg.ws1 = [2] * num_steps_actual
+         cfg.ws2 = [1] * num_steps_actual
+         cfg.real_cfg_scale = 0
+         cfg.real_cfg_scale_save = 0
+         cfg.breakdown = BREAKDOWNS[1]
+         cfg.noise_shift_delta = 1
+         cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
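+         # A negative entry disables norm clipping for that step (see normalize() in
+         # inversion_utils); only the final z is clipped, to a norm of 15.5.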
+
+         cfg.clean_step_timestep = 0
+
+         cfg.model = MODELS[1]
+
+     if cfg.scheduler_type == "ddim":
+         cfg.scheduler_class = DDIMScheduler
+         cfg.step_function = deterministic_ddim_step
+     elif cfg.scheduler_type == "ddpm":
+         cfg.scheduler_class = DDPMScheduler
+         cfg.step_function = deterministic_ddpm_step
+     elif cfg.scheduler_type == "euler":
+         cfg.scheduler_class = EulerAncestralDiscreteScheduler
+         cfg.step_function = deterministic_euler_step
+     elif cfg.scheduler_type == "euler_non_ancestral":
+         cfg.scheduler_class = EulerDiscreteScheduler
+         cfg.step_function = deterministic_non_ancestral_euler_step
+     else:
+         raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
+
+     with cfg.ignore_type():
+         if isinstance(cfg.max_norm_zs, (int, float)):
+             cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
+
+         if isinstance(cfg.ws1, (int, float)):
+             cfg.ws1 = [cfg.ws1] * num_steps_actual
+
+         if isinstance(cfg.ws2, (int, float)):
+             cfg.ws2 = [cfg.ws2] * num_steps_actual
+
+     if not hasattr(cfg, "update_eta"):
+         cfg.update_eta = False
+
+     if not hasattr(cfg, "save_timesteps"):
+         cfg.save_timesteps = None
+
+     if not hasattr(cfg, "scheduler_timesteps"):
+         cfg.scheduler_timesteps = None
+
+     assert (
+         cfg.scheduler_type == "ddpm" or cfg.timesteps is None
+     ), "timesteps must be None for ddim/euler"
+
+     cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
+     assert (
+         len(cfg.max_norm_zs) == num_steps_actual
+     ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
+
+     assert (
+         len(cfg.ws1) == num_steps_actual
+     ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
+
+     assert (
+         len(cfg.ws2) == num_steps_actual
+     ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"
+
+     assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == (
+         num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
+     ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
+
+     assert cfg.save_timesteps is None or len(cfg.save_timesteps) == (
+         num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
+     ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})"
+
+     return cfg
+
+
+ def get_config_name(config, args):
+     if args.folder_name is not None and args.folder_name != "":
+         return args.folder_name
+     timesteps_str = (
+         f"step_start {config.step_start}"
+         if config.timesteps is None
+         else f"timesteps {config.timesteps}"
+     )
+     return f"""\
+ ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \
+ real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \
+ scheduler_type {config.scheduler_type} fp16 {args.fp16}\
+ """
enhance_utils.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import torch
+ import cv2
+ import numpy as np
+ import subprocess
+
+ from PIL import Image
+ from gfpgan.utils import GFPGANer
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
+ from realesrgan.utils import RealESRGANer
+
+ def runcmd(cmd, verbose = False, *args, **kwargs):
+
+     process = subprocess.Popen(
+         cmd,
+         stdout = subprocess.PIPE,
+         stderr = subprocess.PIPE,
+         text = True,
+         shell = True
+     )
+     std_out, std_err = process.communicate()
+     if verbose:
+         print(std_out.strip(), std_err)
+     pass
+
+ runcmd("pip freeze")
+ if not os.path.exists('GFPGANv1.4.pth'):
+     runcmd("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
+ if not os.path.exists('realesr-general-x4v3.pth'):
+     runcmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
+
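+ # Real-ESRGAN (SRVGGNetCompact, x4) is used for general upscaling; fp16 weights only when CUDA is available.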
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ model_path = 'realesr-general-x4v3.pth'
+ half = True if torch.cuda.is_available() else False
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
+
+ def enhance_image(
+     pil_image: Image,
+     enhance_face: bool = True,
+ ):
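+     # Inputs shorter than 300 px are first doubled with Lanczos resizing, then either GFPGAN
+     # face restoration or a 2x Real-ESRGAN upscale is applied depending on enhance_face.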
+     img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+
+     h, w = img.shape[0:2]
+     if h < 300:
+         img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+     if enhance_face:
+         _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
+     else:
+         output, _ = upsampler.enhance(img, outscale=2)
+     pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
+
+     return pil_output
inversion_run_base.py ADDED
@@ -0,0 +1,217 @@
+ import torch
+
+ from diffusers import (
+     DDPMScheduler,
+     StableDiffusionXLImg2ImgPipeline,
+ )
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
+ from PIL import Image
+ from inversion_utils import get_ddpm_inversion_scheduler, create_xts
+ from config import get_config, get_num_steps_actual
+ from functools import partial
+ from compel import Compel, ReturnedEmbeddingsType
+
+ class Object(object):
+     pass
+
+ args = Object()
+ args.images_paths = None
+ args.images_folder = None
+ args.force_use_cpu = False
+ args.folder_name = 'test_measure_time'
+ args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
+ args.save_intermediate_results = False
+ args.batch_size = None
+ args.skip_p_to_p = True
+ args.only_p_to_p = False
+ args.fp16 = False
+ args.prompts_file = 'dataset_measure_time/dataset.json'
+ args.images_in_prompts_file = None
+ args.seed = 986
+ args.time_measure_n = 1
+
+
+ assert (
+     args.batch_size is None or args.save_intermediate_results is False
+ ), "save_intermediate_results is not implemented for batch_size > 1"
+
+ generator = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+ BASE_MODEL = "stabilityai/sdxl-turbo"
+
+
+ pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True,
+ )
+ pipeline = pipeline.to(device)
+
+ pipeline.scheduler = DDPMScheduler.from_pretrained(
+     BASE_MODEL,
+     subfolder="scheduler",
+ )
+
+ config = get_config(args)
+
+ compel_proc = Compel(
+     tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
+     text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+     requires_pooled=[False, True]
+ )
+
+ def run(
+     input_image: Image,
+     src_prompt: str,
+     tgt_prompt: str,
+     seed: int,
+     w1: float,
+     w2: float,
+     num_steps: int,
+     start_step: int,
+     guidance_scale: float,
+ ):
+     generator = torch.Generator().manual_seed(seed)
+
+     config.num_steps_inversion = num_steps
+     config.step_start = start_step
+     num_steps_actual = get_num_steps_actual(config)
+
+
+     num_steps_inversion = config.num_steps_inversion
+     denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
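+     # Only the last num_steps_actual steps of the schedule are denoised; denoising_start tells
+     # the SDXL img2img pipeline to skip the same leading steps that the inversion skips.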
+     print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")
+
+     timesteps, num_inference_steps = retrieve_timesteps(
+         pipeline.scheduler, num_steps_inversion, device, None
+     )
+     timesteps, num_inference_steps = pipeline.get_timesteps(
+         num_inference_steps=num_inference_steps,
+         denoising_start=denoising_start,
+         strength=0,
+         device=device,
+     )
+     timesteps = timesteps.type(torch.int64)
+
+     timesteps = [torch.tensor(t) for t in timesteps.tolist()]
+     timesteps_len = len(timesteps)
+     config.step_start = start_step + num_steps_actual - timesteps_len
+     num_steps_actual = timesteps_len
+     config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
+     print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
+     print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")
+     pipeline.__call__ = partial(
+         pipeline.__call__,
+         num_inference_steps=num_steps_inversion,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         denoising_start=denoising_start,
+         strength=0,
+     )
+
+     x_0_image = input_image
+     x_0 = encode_image(x_0_image, pipeline)
+     x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
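+     # x_ts is the reference noisy-latent trajectory (x_0 noised to each shifted timestep); the
+     # custom scheduler records z_t against it during inversion and anchors the edited samples to it.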
+     x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
+     latents = [x_ts[0]]
+     x_ts_c_hat = [None]
+     config.ws1 = [w1] * num_steps_actual
+     config.ws2 = [w2] * num_steps_actual
+     pipeline.scheduler = get_ddpm_inversion_scheduler(
+         pipeline.scheduler,
+         config.step_function,
+         config,
+         timesteps,
+         config.save_timesteps,
+         latents,
+         x_ts,
+         x_ts_c_hat,
+         args.save_intermediate_results,
+         pipeline,
+         x_0,
+         v1s_images := [],
+         v2s_images := [],
+         deltas_images := [],
+         v1_x0s := [],
+         v2_x0s := [],
+         deltas_x0s := [],
+         "res12",
+         image_name="im_name",
+         time_measure_n=args.time_measure_n,
+     )
+     latent = latents[0].expand(3, -1, -1, -1)
+     prompt = [src_prompt, src_prompt, tgt_prompt]
+     conditioning, pooled = compel_proc(prompt)
+     image = pipeline.__call__(
+         image=latent,
+         prompt_embeds=conditioning,
+         pooled_prompt_embeds=pooled,
+         eta=1,
+     ).images
+     return image[2]
+
+ def encode_image(image, pipe):
+     image = pipe.image_processor.preprocess(image)
+     originDtype = pipe.dtype
+     image = image.to(device=device, dtype=originDtype)
+
+     if pipe.vae.config.force_upcast:
+         image = image.float()
+         pipe.vae.to(dtype=torch.float32)
+
+     if isinstance(generator, list):
+         init_latents = [
+             retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
+             for i in range(1)
+         ]
+         init_latents = torch.cat(init_latents, dim=0)
+     else:
+         init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
+
+     if pipe.vae.config.force_upcast:
+         pipe.vae.to(originDtype)
+
+     init_latents = init_latents.to(originDtype)
+     init_latents = pipe.vae.config.scaling_factor * init_latents
+
+     return init_latents.to(dtype=torch.float16)
+
+ def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
+     # get the original timestep using init_timestep
+     if denoising_start is None:
+         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+         t_start = max(num_inference_steps - init_timestep, 0)
+     else:
+         t_start = 0
+
+     timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
+
+     # Strength is irrelevant if we directly request a timestep to start at;
+     # that is, strength is determined by the denoising_start instead.
+     if denoising_start is not None:
+         discrete_timestep_cutoff = int(
+             round(
+                 pipe.scheduler.config.num_train_timesteps
+                 - (denoising_start * pipe.scheduler.config.num_train_timesteps)
+             )
+         )
+
+         num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+         if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
+             # if the scheduler is a 2nd order scheduler we might have to do +1
+             # because `num_inference_steps` might be even given that every timestep
+             # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+             # mean that we cut the timesteps in the middle of the denoising step
+             # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+             # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+             num_inference_steps = num_inference_steps + 1
+
+         # because t_n+1 >= t_n, we slice the timesteps starting from the end
+         timesteps = timesteps[-num_inference_steps:]
+         return timesteps, num_inference_steps
+
+     return timesteps, num_inference_steps - t_start
inversion_utils.py ADDED
@@ -0,0 +1,794 @@
1
+ import torch
2
+ import os
3
+ import PIL
4
+
5
+ from typing import List, Optional, Union
6
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
+ from PIL import Image
8
+ from diffusers.utils import logging
9
+
10
+ VECTOR_DATA_FOLDER = "vector_data"
11
+ VECTOR_DATA_DICT = "vector_data"
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ def get_ddpm_inversion_scheduler(
16
+ scheduler,
17
+ step_function,
18
+ config,
19
+ timesteps,
20
+ save_timesteps,
21
+ latents,
22
+ x_ts,
23
+ x_ts_c_hat,
24
+ save_intermediate_results,
25
+ pipe,
26
+ x_0,
27
+ v1s_images,
28
+ v2s_images,
29
+ deltas_images,
30
+ v1_x0s,
31
+ v2_x0s,
32
+ deltas_x0s,
33
+ folder_name,
34
+ image_name,
35
+ time_measure_n,
36
+ ):
37
+ def step(
38
+ model_output: torch.FloatTensor,
39
+ timestep: int,
40
+ sample: torch.FloatTensor,
41
+ eta: float = 0.0,
42
+ use_clipped_model_output: bool = False,
43
+ generator=None,
44
+ variance_noise: Optional[torch.FloatTensor] = None,
45
+ return_dict: bool = True,
46
+ ):
47
+ # if scheduler.is_save:
48
+ # start = timer()
49
+ res_inv = step_save_latents(
50
+ scheduler,
51
+ model_output[:1, :, :, :],
52
+ timestep,
53
+ sample[:1, :, :, :],
54
+ eta,
55
+ use_clipped_model_output,
56
+ generator,
57
+ variance_noise,
58
+ return_dict,
59
+ )
60
+ # end = timer()
61
+ # print(f"Run Time Inv: {end - start}")
62
+
63
+ res_inf = step_use_latents(
64
+ scheduler,
65
+ model_output[1:, :, :, :],
66
+ timestep,
67
+ sample[1:, :, :, :],
68
+ eta,
69
+ use_clipped_model_output,
70
+ generator,
71
+ variance_noise,
72
+ return_dict,
73
+ )
74
+ # res = res_inv
75
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
76
+ return res
77
+ # return res
78
+
79
+ scheduler.step_function = step_function
80
+ scheduler.is_save = True
81
+ scheduler._timesteps = timesteps
82
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
83
+ scheduler._config = config
84
+ scheduler.latents = latents
85
+ scheduler.x_ts = x_ts
86
+ scheduler.x_ts_c_hat = x_ts_c_hat
87
+ scheduler.step = step
88
+ scheduler.save_intermediate_results = save_intermediate_results
89
+ scheduler.pipe = pipe
90
+ scheduler.v1s_images = v1s_images
91
+ scheduler.v2s_images = v2s_images
92
+ scheduler.deltas_images = deltas_images
93
+ scheduler.v1_x0s = v1_x0s
94
+ scheduler.v2_x0s = v2_x0s
95
+ scheduler.deltas_x0s = deltas_x0s
96
+ scheduler.clean_step_run = False
97
+ scheduler.x_0s = create_xts(
98
+ config.noise_shift_delta,
99
+ config.noise_timesteps,
100
+ config.clean_step_timestep,
101
+ None,
102
+ pipe.scheduler,
103
+ timesteps,
104
+ x_0,
105
+ no_add_noise=True,
106
+ )
107
+ scheduler.folder_name = folder_name
108
+ scheduler.image_name = image_name
109
+ scheduler.p_to_p = False
110
+ scheduler.p_to_p_replace = False
111
+ scheduler.time_measure_n = time_measure_n
112
+ return scheduler
113
+
114
+ def step_save_latents(
115
+ self,
116
+ model_output: torch.FloatTensor,
117
+ timestep: int,
118
+ sample: torch.FloatTensor,
119
+ eta: float = 0.0,
120
+ use_clipped_model_output: bool = False,
121
+ generator=None,
122
+ variance_noise: Optional[torch.FloatTensor] = None,
123
+ return_dict: bool = True,
124
+ ):
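+ # Inversion branch: predict u_hat_t for the source image, store z_t = x_{t-1} - u_hat_t for reuse,
+ # and return the (norm-clipped) reconstruction u_hat_t + z_t as the next sample.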
125
+ # print(self._save_timesteps)
126
+ # timestep_index = map_timpstep_to_index[timestep]
127
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
128
+ timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
129
+ next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
130
+ u_hat_t = self.step_function(
131
+ model_output=model_output,
132
+ timestep=timestep,
133
+ sample=sample,
134
+ eta=eta,
135
+ use_clipped_model_output=use_clipped_model_output,
136
+ generator=generator,
137
+ variance_noise=variance_noise,
138
+ return_dict=False,
139
+ scheduler=self,
140
+ )
141
+
142
+ x_t_minus_1 = self.x_ts[next_timestep_index]
143
+ self.x_ts_c_hat.append(u_hat_t)
144
+
145
+ z_t = x_t_minus_1 - u_hat_t
146
+ self.latents.append(z_t)
147
+ z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
148
+
149
+ x_t_minus_1_predicted = u_hat_t + z_t
150
+
151
+ if not return_dict:
152
+ return (x_t_minus_1_predicted,)
153
+
154
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
155
+
156
+ def step_use_latents(
157
+ self,
158
+ model_output: torch.FloatTensor,
159
+ timestep: int,
160
+ sample: torch.FloatTensor,
161
+ eta: float = 0.0,
162
+ use_clipped_model_output: bool = False,
163
+ generator=None,
164
+ variance_noise: Optional[torch.FloatTensor] = None,
165
+ return_dict: bool = True,
166
+ ):
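+ # Editing branch: reuse the z_t recorded during inversion and combine the predicted edit
+ # direction(s) with the exact noisy latent x_{t-1}, weighted by ws1/ws2.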
167
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
168
+ timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
169
+ next_timestep_index = (
170
+ timestep_index + 1 if not self.clean_step_run else -1
171
+ )
172
+ z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
173
+
174
+ _, normalize_coefficient = normalize(
175
+ z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
176
+ timestep_index,
177
+ self._config.max_norm_zs,
178
+ )
179
+
180
+ if normalize_coefficient == 0:
181
+ eta = 0
182
+
183
+ # eta = normalize_coefficient
184
+
185
+ x_t_hat_c_hat = self.step_function(
186
+ model_output=model_output,
187
+ timestep=timestep,
188
+ sample=sample,
189
+ eta=eta,
190
+ use_clipped_model_output=use_clipped_model_output,
191
+ generator=generator,
192
+ variance_noise=variance_noise,
193
+ return_dict=False,
194
+ scheduler=self,
195
+ )
196
+
197
+ w1 = self._config.ws1[timestep_index]
198
+ w2 = self._config.ws2[timestep_index]
199
+
200
+ x_t_minus_1_exact = self.x_ts[next_timestep_index]
201
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
202
+
203
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
204
+ if self._config.breakdown == "x_t_c_hat":
205
+ raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
206
+
207
+ # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
208
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
209
+
210
+ # if self._config.breakdown == "x_t_c_hat":
211
+ # v1 = x_t_hat_c_hat - x_t_c_hat
212
+ # v2 = x_t_c_hat - x_t_c
213
+ if (
214
+ self._config.breakdown == "x_t_hat_c"
215
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
216
+ ):
217
+ zero_index_reconstruction = 1 if not self.time_measure_n else 0
218
+ edit_prompts_num = (
219
+ (model_output.size(0) - zero_index_reconstruction) // 3
220
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
221
+ else (model_output.size(0) - zero_index_reconstruction) // 2
222
+ )
223
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
224
+ edit_images_indices = (
225
+ edit_prompts_num + zero_index_reconstruction,
226
+ (
227
+ model_output.size(0)
228
+ if self._config.breakdown == "x_t_hat_c"
229
+ else zero_index_reconstruction + 2 * edit_prompts_num
230
+ ),
231
+ )
232
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
233
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
234
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
235
+ ]
236
+ v1 = x_t_hat_c_hat - x_t_hat_c
237
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
238
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
239
+ path = os.path.join(
240
+ self.folder_name,
241
+ VECTOR_DATA_FOLDER,
242
+ self.image_name,
243
+ )
244
+ if not hasattr(self, VECTOR_DATA_DICT):
245
+ os.makedirs(path, exist_ok=True)
246
+ self.vector_data = dict()
247
+
248
+ x_t_0 = x_t_c_hat[1]
249
+ empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
250
+ x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
251
+
252
+ self.vector_data[timestep.item()] = dict()
253
+ self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
254
+ edit_images_indices[0] : edit_images_indices[1]
255
+ ]
256
+ self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
257
+ self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
258
+ self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
259
+ self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
260
+ edit_images_indices[0] : edit_images_indices[1]
261
+ ]
262
+ self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
263
+ 0
264
+ ].expand_as(x_t_hat_0)
265
+ self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
266
+ next_timestep_index
267
+ ].expand_as(x_t_hat_0)
268
+
269
+ else: # no breakdown
270
+ v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
271
+ v2 = 0
272
+
273
+ if self.save_intermediate_results and not self.p_to_p:
274
+ delta = v1 + v2
275
+ v1_plus_x0 = self.x_0s[next_timestep_index] + v1
276
+ v2_plus_x0 = self.x_0s[next_timestep_index] + v2
277
+ delta_plus_x0 = self.x_0s[next_timestep_index] + delta
278
+
279
+ v1_images = decode_latents(v1, self.pipe)
280
+ self.v1s_images.append(v1_images)
281
+ v2_images = (
282
+ decode_latents(v2, self.pipe)
283
+ if self._config.breakdown != "no_breakdown"
284
+ else [PIL.Image.new("RGB", (1, 1))]
285
+ )
286
+ self.v2s_images.append(v2_images)
287
+ delta_images = decode_latents(delta, self.pipe)
288
+ self.deltas_images.append(delta_images)
289
+ v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
290
+ self.v1_x0s.append(v1_plus_x0_images)
291
+ v2_plus_x0_images = (
292
+ decode_latents(v2_plus_x0, self.pipe)
293
+ if self._config.breakdown != "no_breakdown"
294
+ else [PIL.Image.new("RGB", (1, 1))]
295
+ )
296
+ self.v2_x0s.append(v2_plus_x0_images)
297
+ delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
298
+ self.deltas_x0s.append(delta_plus_x0_images)
299
+
300
+ # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
301
+ # if self._config.breakdown != "no_breakdown":
302
+ # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
303
+ # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
304
+
305
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
306
+
307
+ if (
308
+ self._config.breakdown == "x_t_hat_c"
309
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
310
+ ):
311
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
312
+ edit_images_indices[0] : edit_images_indices[1]
313
+ ] # update x_t_hat_c to be x_t_hat_c_hat
314
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
315
+ x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
316
+ x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
317
+ )
318
+ self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
319
+ edit_images_indices[0] : edit_images_indices[1]
320
+ ]
321
+ if timestep == self._timesteps[-1]:
322
+ torch.save(
323
+ self.vector_data,
324
+ os.path.join(
325
+ path,
326
+ f"{VECTOR_DATA_DICT}.pt",
327
+ ),
328
+ )
329
+ # p_to_p_force_perfect_reconstruction
330
+ if not self.time_measure_n:
331
+ x_t_minus_1[0] = x_t_minus_1_exact[0]
332
+
333
+ if not return_dict:
334
+ return (x_t_minus_1,)
335
+
336
+ return DDIMSchedulerOutput(
337
+ prev_sample=x_t_minus_1,
338
+ pred_original_sample=None,
339
+ )
340
+
341
+ def create_xts(
342
+ noise_shift_delta,
343
+ noise_timesteps,
344
+ clean_step_timestep,
345
+ generator,
346
+ scheduler,
347
+ timesteps,
348
+ x_0,
349
+ no_add_noise=False,
350
+ ):
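+ # Build the sequence of noisy latents x_t by noising x_0 at (shifted) timesteps; with
+ # no_add_noise=True the noise is zero, so this returns deterministically scaled copies of x_0.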
351
+ if noise_timesteps is None:
352
+ noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
353
+ noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
354
+
355
+ first_x_0_idx = len(noise_timesteps)
356
+ for i in range(len(noise_timesteps)):
357
+ if noise_timesteps[i] <= 0:
358
+ first_x_0_idx = i
359
+ break
360
+
361
+ noise_timesteps = noise_timesteps[:first_x_0_idx]
362
+
363
+ x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
364
+ noise = (
365
+ torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
366
+ x_0.device
367
+ )
368
+ if not no_add_noise
369
+ else torch.zeros_like(x_0_expanded)
370
+ )
371
+ x_ts = scheduler.add_noise(
372
+ x_0_expanded,
373
+ noise,
374
+ torch.IntTensor(noise_timesteps),
375
+ )
376
+ x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
377
+ x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
378
+ x_ts += [x_0]
379
+ if clean_step_timestep > 0:
380
+ x_ts += [x_0]
381
+ return x_ts
382
+
383
+ def normalize(
384
+ z_t,
385
+ i,
386
+ max_norm_zs,
387
+ ):
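+ # Clip z_t to max_norm (L2 norm of the whole tensor); a negative max_norm disables clipping.
+ # Returns the possibly rescaled z_t together with the applied scaling coefficient.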
388
+ max_norm = max_norm_zs[i]
389
+ if max_norm < 0:
390
+ return z_t, 1
391
+
392
+ norm = torch.norm(z_t)
393
+ if norm < max_norm:
394
+ return z_t, 1
395
+
396
+ coeff = max_norm / norm
397
+ z_t = z_t * coeff
398
+ return z_t, coeff
399
+
400
+ def decode_latents(latent, pipe):
401
+ latent_img = pipe.vae.decode(
402
+ latent / pipe.vae.config.scaling_factor, return_dict=False
403
+ )[0]
404
+ return pipe.image_processor.postprocess(latent_img, output_type="pil")
405
+
406
+ def deterministic_ddim_step(
407
+ model_output: torch.FloatTensor,
408
+ timestep: int,
409
+ sample: torch.FloatTensor,
410
+ eta: float = 0.0,
411
+ use_clipped_model_output: bool = False,
412
+ generator=None,
413
+ variance_noise: Optional[torch.FloatTensor] = None,
414
+ return_dict: bool = True,
415
+ scheduler=None,
416
+ ):
417
+
418
+ if scheduler.num_inference_steps is None:
419
+ raise ValueError(
420
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
421
+ )
422
+
423
+ prev_timestep = (
424
+ timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
425
+ )
426
+
427
+ # 2. compute alphas, betas
428
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
429
+ alpha_prod_t_prev = (
430
+ scheduler.alphas_cumprod[prev_timestep]
431
+ if prev_timestep >= 0
432
+ else scheduler.final_alpha_cumprod
433
+ )
434
+
435
+ beta_prod_t = 1 - alpha_prod_t
436
+
437
+ if scheduler.config.prediction_type == "epsilon":
438
+ pred_original_sample = (
439
+ sample - beta_prod_t ** (0.5) * model_output
440
+ ) / alpha_prod_t ** (0.5)
441
+ pred_epsilon = model_output
442
+ elif scheduler.config.prediction_type == "sample":
443
+ pred_original_sample = model_output
444
+ pred_epsilon = (
445
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
446
+ ) / beta_prod_t ** (0.5)
447
+ elif scheduler.config.prediction_type == "v_prediction":
448
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
449
+ beta_prod_t**0.5
450
+ ) * model_output
451
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
452
+ else:
453
+ raise ValueError(
454
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
455
+ " `v_prediction`"
456
+ )
457
+
458
+ # 4. Clip or threshold "predicted x_0"
459
+ if scheduler.config.thresholding:
460
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
461
+ elif scheduler.config.clip_sample:
462
+ pred_original_sample = pred_original_sample.clamp(
463
+ -scheduler.config.clip_sample_range,
464
+ scheduler.config.clip_sample_range,
465
+ )
466
+
467
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
468
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
469
+ variance = scheduler._get_variance(timestep, prev_timestep)
470
+ std_dev_t = eta * variance ** (0.5)
471
+
472
+ if use_clipped_model_output:
473
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
474
+ pred_epsilon = (
475
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
476
+ ) / beta_prod_t ** (0.5)
477
+
478
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
479
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
480
+ 0.5
481
+ ) * pred_epsilon
482
+
483
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
484
+ prev_sample = (
485
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
486
+ )
487
+ return prev_sample
488
+
489
+
490
+ def deterministic_euler_step(
491
+ model_output: torch.FloatTensor,
492
+ timestep: Union[float, torch.FloatTensor],
493
+ sample: torch.FloatTensor,
494
+ eta,
495
+ use_clipped_model_output,
496
+ generator,
497
+ variance_noise,
498
+ return_dict,
499
+ scheduler,
500
+ ):
501
+ """
502
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
503
+ process from the learned model outputs (most often the predicted noise).
504
+
505
+ Args:
506
+ model_output (`torch.FloatTensor`):
507
+ The direct output from learned diffusion model.
508
+ timestep (`float`):
509
+ The current discrete timestep in the diffusion chain.
510
+ sample (`torch.FloatTensor`):
511
+ A current instance of a sample created by the diffusion process.
512
+ generator (`torch.Generator`, *optional*):
513
+ A random number generator.
514
+ return_dict (`bool`):
515
+ Whether or not to return a
516
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
517
+
518
+ Returns:
519
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
520
+ If return_dict is `True`,
521
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
522
+ otherwise a tuple is returned where the first element is the sample tensor.
523
+
524
+ """
525
+
526
+ if (
527
+ isinstance(timestep, int)
528
+ or isinstance(timestep, torch.IntTensor)
529
+ or isinstance(timestep, torch.LongTensor)
530
+ ):
531
+ raise ValueError(
532
+ (
533
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
534
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
535
+ " one of the `scheduler.timesteps` as a timestep."
536
+ ),
537
+ )
538
+
539
+ if scheduler.step_index is None:
540
+ scheduler._init_step_index(timestep)
541
+
542
+ sigma = scheduler.sigmas[scheduler.step_index]
543
+
544
+ # Upcast to avoid precision issues when computing prev_sample
545
+ sample = sample.to(torch.float32)
546
+
547
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
548
+ if scheduler.config.prediction_type == "epsilon":
549
+ pred_original_sample = sample - sigma * model_output
550
+ elif scheduler.config.prediction_type == "v_prediction":
551
+ # * c_out + input * c_skip
552
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
553
+ sample / (sigma**2 + 1)
554
+ )
555
+ elif scheduler.config.prediction_type == "sample":
556
+ raise NotImplementedError("prediction_type not implemented yet: sample")
557
+ else:
558
+ raise ValueError(
559
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
560
+ )
561
+
562
+ sigma_from = scheduler.sigmas[scheduler.step_index]
563
+ sigma_to = scheduler.sigmas[scheduler.step_index + 1]
564
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
565
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
566
+
567
+ # 2. Convert to an ODE derivative
568
+ derivative = (sample - pred_original_sample) / sigma
569
+
570
+ dt = sigma_down - sigma
571
+
572
+ prev_sample = sample + derivative * dt
573
+
574
+ # Cast sample back to model compatible dtype
575
+ prev_sample = prev_sample.to(model_output.dtype)
576
+
577
+ # upon completion increase step index by one
578
+ scheduler._step_index += 1
579
+
580
+ return prev_sample
581
+
582
+
583
+ def deterministic_non_ancestral_euler_step(
584
+ model_output: torch.FloatTensor,
585
+ timestep: Union[float, torch.FloatTensor],
586
+ sample: torch.FloatTensor,
587
+ eta: float = 0.0,
588
+ use_clipped_model_output: bool = False,
589
+ s_churn: float = 0.0,
590
+ s_tmin: float = 0.0,
591
+ s_tmax: float = float("inf"),
592
+ s_noise: float = 1.0,
593
+ generator: Optional[torch.Generator] = None,
594
+ variance_noise: Optional[torch.FloatTensor] = None,
595
+ return_dict: bool = True,
596
+ scheduler=None,
597
+ ):
598
+ """
599
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
600
+ process from the learned model outputs (most often the predicted noise).
601
+
602
+ Args:
603
+ model_output (`torch.FloatTensor`):
604
+ The direct output from learned diffusion model.
605
+ timestep (`float`):
606
+ The current discrete timestep in the diffusion chain.
607
+ sample (`torch.FloatTensor`):
608
+ A current instance of a sample created by the diffusion process.
609
+ s_churn (`float`):
610
+ s_tmin (`float`):
611
+ s_tmax (`float`):
612
+ s_noise (`float`, defaults to 1.0):
613
+ Scaling factor for noise added to the sample.
614
+ generator (`torch.Generator`, *optional*):
615
+ A random number generator.
616
+ return_dict (`bool`):
617
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
618
+ tuple.
619
+
620
+ Returns:
621
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
622
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
623
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
624
+ """
625
+
626
+ if (
627
+ isinstance(timestep, int)
628
+ or isinstance(timestep, torch.IntTensor)
629
+ or isinstance(timestep, torch.LongTensor)
630
+ ):
631
+ raise ValueError(
632
+ (
633
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
634
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
635
+ " one of the `scheduler.timesteps` as a timestep."
636
+ ),
637
+ )
638
+
639
+ if not scheduler.is_scale_input_called:
640
+ logger.warning(
641
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
642
+ "See `StableDiffusionPipeline` for a usage example."
643
+ )
644
+
645
+ if scheduler.step_index is None:
646
+ scheduler._init_step_index(timestep)
647
+
648
+ # Upcast to avoid precision issues when computing prev_sample
649
+ sample = sample.to(torch.float32)
650
+
651
+ sigma = scheduler.sigmas[scheduler.step_index]
652
+
653
+ gamma = (
654
+ min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
655
+ if s_tmin <= sigma <= s_tmax
656
+ else 0.0
657
+ )
658
+
659
+ sigma_hat = sigma * (gamma + 1)
660
+
661
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
662
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
663
+ # backwards compatibility
664
+ if (
665
+ scheduler.config.prediction_type == "original_sample"
666
+ or scheduler.config.prediction_type == "sample"
667
+ ):
668
+ pred_original_sample = model_output
669
+ elif scheduler.config.prediction_type == "epsilon":
670
+ pred_original_sample = sample - sigma_hat * model_output
671
+ elif scheduler.config.prediction_type == "v_prediction":
672
+ # denoised = model_output * c_out + input * c_skip
673
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
674
+ sample / (sigma**2 + 1)
675
+ )
676
+ else:
677
+ raise ValueError(
678
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
679
+ )
680
+
681
+ # 2. Convert to an ODE derivative
682
+ derivative = (sample - pred_original_sample) / sigma_hat
683
+
684
+ dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
685
+
686
+ prev_sample = sample + derivative * dt
687
+
688
+ # Cast sample back to model compatible dtype
689
+ prev_sample = prev_sample.to(model_output.dtype)
690
+
691
+ # upon completion increase step index by one
692
+ scheduler._step_index += 1
693
+
694
+ return prev_sample
695
+
696
+
697
+ def deterministic_ddpm_step(
698
+ model_output: torch.FloatTensor,
699
+ timestep: Union[float, torch.FloatTensor],
700
+ sample: torch.FloatTensor,
701
+ eta,
702
+ use_clipped_model_output,
703
+ generator,
704
+ variance_noise,
705
+ return_dict,
706
+ scheduler,
707
+ ):
708
+ """
709
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
710
+ process from the learned model outputs (most often the predicted noise).
711
+
712
+ Args:
713
+ model_output (`torch.FloatTensor`):
714
+ The direct output from learned diffusion model.
715
+ timestep (`float`):
716
+ The current discrete timestep in the diffusion chain.
717
+ sample (`torch.FloatTensor`):
718
+ A current instance of a sample created by the diffusion process.
719
+ generator (`torch.Generator`, *optional*):
720
+ A random number generator.
721
+ return_dict (`bool`, *optional*, defaults to `True`):
722
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
723
+
724
+ Returns:
725
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
726
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
727
+ tuple is returned where the first element is the sample tensor.
728
+
729
+ """
730
+ t = timestep
731
+
732
+ prev_t = scheduler.previous_timestep(t)
733
+
734
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
735
+ "learned",
736
+ "learned_range",
737
+ ]:
738
+ model_output, predicted_variance = torch.split(
739
+ model_output, sample.shape[1], dim=1
740
+ )
741
+ else:
742
+ predicted_variance = None
743
+
744
+ # 1. compute alphas, betas
745
+ alpha_prod_t = scheduler.alphas_cumprod[t]
746
+ alpha_prod_t_prev = (
747
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
748
+ )
749
+ beta_prod_t = 1 - alpha_prod_t
750
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
751
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
752
+ current_beta_t = 1 - current_alpha_t
753
+
754
+ # 2. compute predicted original sample from predicted noise also called
755
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
756
+ if scheduler.config.prediction_type == "epsilon":
757
+ pred_original_sample = (
758
+ sample - beta_prod_t ** (0.5) * model_output
759
+ ) / alpha_prod_t ** (0.5)
760
+ elif scheduler.config.prediction_type == "sample":
761
+ pred_original_sample = model_output
762
+ elif scheduler.config.prediction_type == "v_prediction":
763
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
764
+ beta_prod_t**0.5
765
+ ) * model_output
766
+ else:
767
+ raise ValueError(
768
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
769
+ " `v_prediction` for the DDPMScheduler."
770
+ )
771
+
772
+ # 3. Clip or threshold "predicted x_0"
773
+ if scheduler.config.thresholding:
774
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
775
+ elif scheduler.config.clip_sample:
776
+ pred_original_sample = pred_original_sample.clamp(
777
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
778
+ )
779
+
780
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
781
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
782
+ pred_original_sample_coeff = (
783
+ alpha_prod_t_prev ** (0.5) * current_beta_t
784
+ ) / beta_prod_t
785
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
786
+
787
+ # 5. Compute predicted previous sample µ_t
788
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
789
+ pred_prev_sample = (
790
+ pred_original_sample_coeff * pred_original_sample
791
+ + current_sample_coeff * sample
792
+ )
793
+
794
+ return pred_prev_sample
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ ml-collections
+ gradio
+ torch
+ torchvision
+ diffusers
+ transformers
+ accelerate
+ spaces
+ sentencepiece
+ compel
+ gfpgan
+ git+https://github.com/XPixelGroup/BasicSR@master
+ facexlib
+ realesrgan
run_configs/noise_shift_3_steps.yaml ADDED
@@ -0,0 +1,19 @@
+ breakdown: "x_t_hat_c"
+ cross_r: 0.9
+ eta_reconstruct: 1
+ eta_retrieve: 1
+ max_norm_zs: [-1, -1, 15.5]
+ model: "stabilityai/sdxl-turbo"
+ noise_shift_delta: 1
+ noise_timesteps: [599, 299, 0]
+ timesteps: [799, 499, 199]
+ num_steps_inversion: 5
+ step_start: 1
+ real_cfg_scale: 0
+ real_cfg_scale_save: 0
+ scheduler_type: "ddpm"
+ seed: 2
+ self_r: 0.5
+ ws1: 1.5
+ ws2: 1
+ clean_step_timestep: 0
run_configs/noise_shift_guidance_1_5.yaml ADDED
@@ -0,0 +1,18 @@
+ breakdown: "x_t_hat_c"
+ cross_r: 0.9
+ eta: 1
+ max_norm_zs: [-1, -1, -1, 15.5]
+ model: ""
+ noise_shift_delta: 1
+ noise_timesteps: null
+ num_steps_inversion: 20
+ step_start: 5
+ real_cfg_scale: 0
+ real_cfg_scale_save: 0
+ scheduler_type: "ddpm"
+ seed: 2
+ self_r: 0.5
+ timesteps: null
+ ws1: 1.5
+ ws2: 1
+ clean_step_timestep: 0