import torch

from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    retrieve_timesteps,
    retrieve_latents,
)
from PIL import Image
from inversion_utils import get_ddpm_inversion_scheduler, create_xts
from config import get_config, get_num_steps_actual
from functools import partial
from compel import Compel, ReturnedEmbeddingsType


class Object(object):
    pass

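
# Hard-coded run arguments: a bare namespace object stands in for an argparse
# Namespace; these fields are handed to get_config(args) below.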
args = Object()
args.images_paths = None
args.images_folder = None
args.force_use_cpu = False
args.folder_name = 'test_measure_time'
args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
args.save_intermediate_results = False
args.batch_size = None
args.skip_p_to_p = True
args.only_p_to_p = False
args.fp16 = False
args.prompts_file = 'dataset_measure_time/dataset.json'
args.images_in_prompts_file = None
args.seed = 986
args.time_measure_n = 1

assert (
    args.batch_size is None or args.save_intermediate_results is False
), "save_intermediate_results is not implemented for batch_size > 1"

generator = None
device = "cuda" if torch.cuda.is_available() else "cpu"

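# Load SDXL-Turbo in fp16 and swap its default scheduler for DDPMScheduler,
# which the DDPM-inversion utilities below operate on.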
BASE_MODEL = "stabilityai/sdxl-turbo"

pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline = pipeline.to(device)

pipeline.scheduler = DDPMScheduler.from_pretrained(
    BASE_MODEL,
    subfolder="scheduler",
)

config = get_config(args)

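# Compel builds the SDXL prompt embeddings from both text encoders: per-token
# conditioning from each encoder plus the pooled embedding from the second one.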
compel_proc = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)


def run(
    input_image: Image.Image,
    src_prompt: str,
    tgt_prompt: str,
    seed: int,
    w1: float,
    w2: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
):
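    """Run a single image edit with the DDPM-inversion utilities: encode and
    noise `input_image`, then denoise a [src, src, tgt] prompt batch and return
    the image conditioned on `tgt_prompt`. `w1`/`w2` are broadcast into the
    per-step lists `config.ws1`/`config.ws2`."""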
    generator = torch.Generator().manual_seed(seed)

    config.num_steps_inversion = num_steps
    config.step_start = start_step
    num_steps_actual = get_num_steps_actual(config)

    num_steps_inversion = config.num_steps_inversion
    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion

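    # Build the timestep schedule, then let denoising_start trim it so that only
    # the last num_steps_actual timesteps remain; step_start and num_steps_actual
    # are re-derived from the trimmed length below.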
    timesteps, num_inference_steps = retrieve_timesteps(
        pipeline.scheduler, num_steps_inversion, device, None
    )
    timesteps, num_inference_steps = pipeline.get_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_start,
        strength=0,
        device=device,
    )
    timesteps = timesteps.type(torch.int64)

    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
    timesteps_len = len(timesteps)
    config.step_start = start_step + num_steps_actual - timesteps_len
    num_steps_actual = timesteps_len
    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]

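    # Pin the sampling arguments on the pipeline call so the edit call at the end
    # only has to pass the latents and prompt embeddings. Note that every call to
    # run() wraps __call__ in another functools.partial.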
    pipeline.__call__ = partial(
        pipeline.__call__,
        num_inference_steps=num_steps_inversion,
        guidance_scale=guidance_scale,
        generator=generator,
        denoising_start=denoising_start,
        strength=0,
    )

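    # Encode the input image to VAE latents and build the sequence of noised
    # latents x_t for the inversion (create_xts comes from inversion_utils).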
    x_0_image = input_image
    x_0 = encode_image(x_0_image, pipeline)
    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
    latents = [x_ts[0]]
    x_ts_c_hat = [None]
    config.ws1 = [w1] * num_steps_actual
    config.ws2 = [w2] * num_steps_actual
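    # Replace the scheduler with the DDPM-inversion wrapper; the empty lists are
    # provided so it can collect intermediate results when
    # args.save_intermediate_results is enabled.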
    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config.step_function,
        config,
        timesteps,
        config.save_timesteps,
        latents,
        x_ts,
        x_ts_c_hat,
        args.save_intermediate_results,
        pipeline,
        x_0,
        v1s_images := [],
        v2s_images := [],
        deltas_images := [],
        v1_x0s := [],
        v2_x0s := [],
        deltas_x0s := [],
        "res12",
        image_name="im_name",
        time_measure_n=args.time_measure_n,
    )
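    # Denoise a batch of three latents: two conditioned on the source prompt and
    # one on the target prompt; the edited image is the third output.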
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    conditioning, pooled = compel_proc(prompt)
    image = pipeline.__call__(
        image=latent,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        eta=1,
    ).images
    return image[2]


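# VAE-encode an image to latents, mirroring the diffusers img2img pipelines:
# temporarily upcast the VAE to fp32 when force_upcast is set, then scale the
# latents by the VAE's scaling_factor. The batch size is fixed to one image here.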
def encode_image(image, pipe):
    image = pipe.image_processor.preprocess(image)
    orig_dtype = pipe.dtype
    image = image.to(device=device, dtype=orig_dtype)

    if pipe.vae.config.force_upcast:
        image = image.float()
        pipe.vae.to(dtype=torch.float32)

    if isinstance(generator, list):
        init_latents = [
            retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
            for i in range(1)
        ]
        init_latents = torch.cat(init_latents, dim=0)
    else:
        init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)

    if pipe.vae.config.force_upcast:
        pipe.vae.to(orig_dtype)

    init_latents = init_latents.to(orig_dtype)
    init_latents = pipe.vae.config.scaling_factor * init_latents

    return init_latents.to(dtype=torch.float16)


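# Standalone copy of StableDiffusionXLImg2ImgPipeline.get_timesteps; run() above
# uses the pipeline's own method, so this function is not called in this file.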
def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
    else:
        t_start = 0

    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]

    if denoising_start is not None:
        discrete_timestep_cutoff = int(
            round(
                pipe.scheduler.config.num_train_timesteps
                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
            )
        )

        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
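            # For a second-order scheduler, an even step count would make the slice
            # end on a duplicated (derivative) timestep, so add one step to end on a
            # full denoising step (the same adjustment the diffusers pipelines make).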
            num_inference_steps = num_inference_steps + 1

        timesteps = timesteps[-num_inference_steps:]
        return timesteps, num_inference_steps

    return timesteps, num_inference_steps - t_start