"""Step overrides that patch a diffusers scheduler for latent-saving inversion.

``get_ddpm_inversion_scheduler`` attaches state to a scheduler object and
replaces its ``step`` with a wrapper that sends the first batch element
through ``step_save_latents`` and the remaining elements through
``step_use_latents``.  The ``deterministic_*_step`` functions are the
noise-free update rules that can be plugged in as ``step_function``.
"""

import os
from typing import List, Optional, Union

import PIL
import torch
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.utils import logging
from PIL import Image

VECTOR_DATA_FOLDER = "vector_data"
VECTOR_DATA_DICT = "vector_data"

logger = logging.get_logger(__name__)


def get_ddpm_inversion_scheduler(
    scheduler,
    step_function,
    config,
    timesteps,
    save_timesteps,
    latents,
    x_ts,
    x_ts_c_hat,
    save_intermediate_results,
    pipe,
    x_0,
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    image_name,
    time_measure_n,
):
    """Attach inversion state to ``scheduler`` and override its ``step``.

    The scheduler object is mutated in place: every argument is stored on it
    as an attribute and ``scheduler.step`` is replaced by a closure that
    splits the batch between ``step_save_latents`` (first element) and
    ``step_use_latents`` (all remaining elements).

    Args:
        scheduler: Scheduler object to patch.
        step_function: Callable used by both step helpers to compute the
            deterministic update; it is called with keyword arguments,
            including ``scheduler=``.
        config: Run configuration.  This function reads
            ``noise_shift_delta``, ``noise_timesteps`` and
            ``clean_step_timestep``; the step helpers read ``max_norm_zs``,
            ``ws1``, ``ws2`` and ``breakdown``.
        timesteps: Timesteps looked up by ``step_use_latents``.
        save_timesteps: Timesteps looked up by ``step_save_latents``; falls
            back to ``timesteps`` when falsy.
        latents: List that ``step_save_latents`` appends ``z_t`` to and that
            ``step_use_latents`` reads from.
        x_ts: Sequence of stored samples indexed by step.
        x_ts_c_hat: List that ``step_save_latents`` appends its prediction to.
        save_intermediate_results: When truthy, ``step_use_latents`` decodes
            and records intermediate images.
        pipe: Pipeline providing ``scheduler``, ``vae`` and
            ``image_processor``.
        x_0: Clean latent used to build ``scheduler.x_0s`` without noise.
        v1s_images, v2s_images, deltas_images, v1_x0s, v2_x0s, deltas_x0s:
            Lists that receive decoded intermediate images.
        folder_name, image_name: Path components for the saved vector data.
        time_measure_n: When truthy, ``step_use_latents`` does not treat
            index 0 of its batch as a reconstruction sample.

    Returns:
        The same ``scheduler`` object, patched.
    """

    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # The first batch element is the sample whose latents are recorded.
        res_inv = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            return_dict,
        )

        # Every other element consumes the latents recorded above.
        res_inf = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            return_dict,
        )

        # NOTE(review): a 1-tuple is returned whatever ``return_dict`` is, so
        # callers must index with ``[0]`` rather than use ``.prev_sample``.
        return (torch.cat((res_inv[0], res_inf[0]), dim=0),)

    scheduler.step_function = step_function
    scheduler.is_save = True
    scheduler._timesteps = timesteps
    scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
    scheduler._config = config
    scheduler.latents = latents
    scheduler.x_ts = x_ts
    scheduler.x_ts_c_hat = x_ts_c_hat
    scheduler.step = step
    scheduler.save_intermediate_results = save_intermediate_results
    scheduler.pipe = pipe
    scheduler.v1s_images = v1s_images
    scheduler.v2s_images = v2s_images
    scheduler.deltas_images = deltas_images
    scheduler.v1_x0s = v1_x0s
    scheduler.v2_x0s = v2_x0s
    scheduler.deltas_x0s = deltas_x0s
    scheduler.clean_step_run = False
    # Same layout as ``x_ts`` but built with zero noise (no_add_noise=True).
    scheduler.x_0s = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        config.clean_step_timestep,
        None,
        pipe.scheduler,
        timesteps,
        x_0,
        no_add_noise=True,
    )
    scheduler.folder_name = folder_name
    scheduler.image_name = image_name
    scheduler.p_to_p = False
    scheduler.p_to_p_replace = False
    scheduler.time_measure_n = time_measure_n
    return scheduler


def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Run one step and record the residual between stored and predicted samples.

    Calls ``self.step_function`` to get the prediction ``u_hat_t``, appends it
    to ``self.x_ts_c_hat``, and appends ``z_t = self.x_ts[next] - u_hat_t``
    (before norm clipping) to ``self.latents``.

    Args:
        self: Patched scheduler (see ``get_ddpm_inversion_scheduler``).
        model_output, timestep, sample, eta, use_clipped_model_output,
            generator, variance_noise: Forwarded to ``self.step_function``.
        return_dict: Selects the return form, see below.

    Returns:
        ``(u_hat_t + clipped z_t,)`` when ``return_dict`` is false, otherwise
        a ``DDIMSchedulerOutput`` whose ``prev_sample`` is the stored
        ``self.x_ts`` entry.
    """
    if self.clean_step_run:
        # During the clean step both indices refer to the last stored entry.
        timestep_index = -1
        next_timestep_index = -1
    else:
        timestep_index = self._save_timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    u_hat_t = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_hat.append(u_hat_t)

    # The unclipped residual is what gets stored; clipping only affects the
    # sample returned from this call.
    z_t = x_t_minus_1 - u_hat_t
    self.latents.append(z_t)

    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
    x_t_minus_1_predicted = u_hat_t + z_t

    if not return_dict:
        return (x_t_minus_1_predicted,)

    # NOTE(review): this branch returns the stored sample, not
    # ``x_t_minus_1_predicted``; the two differ whenever ``normalize`` clips.
    # Left as-is -- confirm which one callers expect.
    return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)


def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Run one step using the latents recorded by ``step_save_latents``.

    The update is ``coeff * x_ts[next] + w1 * v1 + w2 * v2`` where ``coeff``
    is the clipping coefficient from ``normalize`` and ``v1``/``v2`` depend on
    ``self._config.breakdown``:

    * ``"x_t_hat_c"`` / ``"x_t_hat_c_with_zeros"``: the batch is sliced into
      an ``x_t_hat_c`` group and an edit-image group (plus an empty-prompt
      group for ``"x_t_hat_c_with_zeros"`` when ``self.p_to_p`` is false).
    * ``"x_t_c_hat"``: not implemented.
    * anything else: ``v1 = prediction - coeff * x_t_c`` and ``v2 = 0``.

    Side effects: may append decoded images to the ``self.*_images`` /
    ``self.*_x0s`` lists, and in ``"x_t_hat_c_with_zeros"`` mode records
    tensors in ``self.vector_data`` and saves that dict with ``torch.save``
    when ``timestep`` equals ``self._timesteps[-1]``.

    Args:
        self: Patched scheduler (see ``get_ddpm_inversion_scheduler``).
        model_output, timestep, sample, use_clipped_model_output, generator,
            variance_noise: Forwarded to ``self.step_function``.
        eta: Forwarded to ``self.step_function``; forced to 0 when the
            clipping coefficient is 0.
        return_dict: When false a 1-tuple is returned.

    Returns:
        ``(x_t_minus_1,)`` or a ``DDIMSchedulerOutput`` wrapping it.

    Raises:
        NotImplementedError: If ``self._config.breakdown == "x_t_c_hat"``.
    """
    if self.clean_step_run:
        timestep_index = -1
        next_timestep_index = -1
    else:
        timestep_index = self._timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    breakdown = self._config.breakdown
    uses_breakdown = breakdown in ("x_t_hat_c", "x_t_hat_c_with_zeros")
    saves_vectors = breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p

    z_t = self.latents[next_timestep_index]  # + 1 because latents[0] is X_T

    _, normalize_coefficient = normalize(
        z_t[0] if breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )

    if normalize_coefficient == 0:
        eta = 0

    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]

    x_t_minus_1_exact = self.x_ts[next_timestep_index].expand_as(x_t_hat_c_hat)

    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if breakdown == "x_t_c_hat":
        # Intended split (never implemented):
        #   v1 = x_t_hat_c_hat - x_t_c_hat, v2 = x_t_c_hat - x_t_c
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")

    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    if uses_breakdown:
        batch_size = model_output.size(0)
        # Index 0 is reserved for the reconstruction unless time_measure_n.
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        groups = 3 if saves_vectors else 2
        edit_prompts_num = (batch_size - zero_index_reconstruction) // groups

        x_t_hat_c_indices = (
            zero_index_reconstruction,
            edit_prompts_num + zero_index_reconstruction,
        )
        edit_images_indices = (
            edit_prompts_num + zero_index_reconstruction,
            (
                batch_size
                if breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        edit_slice = slice(edit_images_indices[0], edit_images_indices[1])
        hat_c_slice = slice(x_t_hat_c_indices[0], x_t_hat_c_indices[1])

        # Place the x_t_hat_c predictions at the edit-image positions; every
        # other position stays zero.
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_slice] = x_t_hat_c_hat[hat_c_slice]

        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c

        if saves_vectors:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()

            x_t_0 = x_t_c_hat[1]
            # The empty-prompt group follows the edit-image group.  Offset by
            # zero_index_reconstruction (not a literal 1) so that the slice
            # stays inside the batch when time_measure_n is set.
            empty_prompt_indices = (
                zero_index_reconstruction + 2 * edit_prompts_num,
                zero_index_reconstruction + 3 * edit_prompts_num,
            )
            empty_slice = slice(empty_prompt_indices[0], empty_prompt_indices[1])
            x_t_hat_0 = x_t_hat_c_hat[empty_slice]

            step_key = timestep.item()
            self.vector_data[step_key] = {
                "x_t_hat_c": x_t_hat_c[edit_slice],
                "x_t_hat_0": x_t_hat_0,
                "x_t_c": x_t_c[0].expand_as(x_t_hat_0),
                "x_t_0": x_t_0.expand_as(x_t_hat_0),
                "x_t_hat_c_hat": x_t_hat_c_hat[edit_slice],
                "x_t_minus_1_noisy": x_t_minus_1_exact[0].expand_as(x_t_hat_0),
                "x_t_minus_1_clean": self.x_0s[next_timestep_index].expand_as(
                    x_t_hat_0
                ),
            }
    else:  # no breakdown
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        x_0_next = self.x_0s[next_timestep_index]
        delta = v1 + v2
        has_v2 = breakdown != "no_breakdown"

        self.v1s_images.append(decode_latents(v1, self.pipe))
        # Without a breakdown v2 is the scalar 0, so a 1x1 placeholder image
        # is recorded instead of decoding it.
        self.v2s_images.append(
            decode_latents(v2, self.pipe)
            if has_v2
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_images.append(decode_latents(delta, self.pipe))
        self.v1_x0s.append(decode_latents(x_0_next + v1, self.pipe))
        self.v2_x0s.append(
            decode_latents(x_0_next + v2, self.pipe)
            if has_v2
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_x0s.append(decode_latents(x_0_next + delta, self.pipe))

    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2

    if uses_breakdown:
        # update x_t_hat_c to be x_t_hat_c_hat
        x_t_minus_1[hat_c_slice] = x_t_minus_1[edit_slice]
        if saves_vectors:
            x_t_minus_1[empty_slice] = x_t_minus_1[edit_slice]
            self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
                edit_slice
            ]
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(
                        path,
                        f"{VECTOR_DATA_DICT}.pt",
                    ),
                )

    # p_to_p_force_perfect_reconstruction
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]

    if not return_dict:
        return (x_t_minus_1,)

    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )


def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    """Build the list of per-step samples derived from ``x_0``.

    Args:
        noise_shift_delta: Multiplier applied to ``timesteps[0] -
            timesteps[1]`` to shift the noising timesteps; only used when
            ``noise_timesteps`` is None.
        noise_timesteps: Explicit noising timesteps, or None to derive them
            from ``timesteps``.
        clean_step_timestep: When greater than 0, one extra ``x_0`` entry is
            appended at the end.
        generator: Optional ``torch.Generator`` for the CPU noise draw.
        scheduler: Object providing ``add_noise(samples, noise, timesteps)``.
        timesteps: Timestep sequence; needs at least two entries when
            ``noise_timesteps`` is None.
        x_0: Clean latent; expanded along dim 0 with three trailing dims kept.
        no_add_noise: Use zero noise instead of a random draw.

    Returns:
        A list with one noised sample per positive noising timestep (each
        with a leading dim of 1), then ``x_0`` once for every remaining entry
        of ``timesteps``, once more unconditionally, and once more again if
        ``clean_step_timestep > 0``.
    """
    if noise_timesteps is None:
        noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
        noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]

    # From the first non-positive noising timestep onward the sample is x_0.
    first_x_0_idx = next(
        (idx for idx, t in enumerate(noise_timesteps) if t <= 0),
        len(noise_timesteps),
    )
    noise_timesteps = noise_timesteps[:first_x_0_idx]

    x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
    if no_add_noise:
        noise = torch.zeros_like(x_0_expanded)
    else:
        # Drawn on CPU and then moved to x_0's device.
        noise = torch.randn(
            x_0_expanded.size(), generator=generator, device="cpu"
        ).to(x_0.device)

    noised = scheduler.add_noise(
        x_0_expanded,
        noise,
        torch.IntTensor(noise_timesteps),
    )
    x_ts = [t.unsqueeze(dim=0) for t in noised]
    x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
    x_ts += [x_0]
    if clean_step_timestep > 0:
        x_ts += [x_0]
    return x_ts


def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clip the norm of ``z_t`` to ``max_norm_zs[i]``.

    Args:
        z_t: Tensor to clip.
        i: Index into ``max_norm_zs``.
        max_norm_zs: Per-step maximum norms; a negative entry disables
            clipping for that step.

    Returns:
        ``(z_t, coeff)`` where ``coeff`` is ``1`` when nothing was clipped and
        ``max_norm / norm`` otherwise.  A zero-norm ``z_t`` with
        ``max_norm == 0`` yields a zero coefficient.
    """
    max_norm = max_norm_zs[i]
    if max_norm < 0:
        return z_t, 1

    norm = torch.norm(z_t)
    if norm < max_norm:
        return z_t, 1

    if norm == 0:
        # Only reachable with max_norm == 0.  Dividing would give 0 / 0 = NaN
        # and poison every later step; any non-zero norm gives a coefficient
        # of 0 here, so return 0 for the zero-norm case as well.
        return z_t, torch.zeros_like(norm)

    coeff = max_norm / norm
    return z_t * coeff, coeff


def decode_latents(latent, pipe):
    """Decode ``latent`` with ``pipe.vae`` and post-process the result to PIL.

    The latent is divided by ``pipe.vae.config.scaling_factor`` before
    decoding; the output of ``pipe.image_processor.postprocess`` is returned.
    """
    latent_img = pipe.vae.decode(
        latent / pipe.vae.config.scaling_factor, return_dict=False
    )[0]
    return pipe.image_processor.postprocess(latent_img, output_type="pil")


def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Compute the DDIM previous sample without adding the random-noise term.

    Args:
        model_output: Output of the diffusion model.
        timestep: Current timestep; indexes ``scheduler.alphas_cumprod``.
        sample: Current sample.
        eta: Scales ``std_dev_t``, which only enters the direction term.
        use_clipped_model_output: Re-derive the predicted noise from the
            clipped/thresholded ``x_0`` prediction.
        generator, variance_noise, return_dict: Accepted for signature
            compatibility; not used.
        scheduler: Scheduler providing ``alphas_cumprod``,
            ``final_alpha_cumprod``, ``config``, ``num_inference_steps``,
            ``_get_variance`` and ``_threshold_sample``.  Required despite the
            ``None`` default.

    Returns:
        The previous-sample tensor (a bare tensor, not an output object).

    Raises:
        ValueError: If ``scheduler.num_inference_steps`` is None or the
            prediction type is unsupported.
    """
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # 1. get previous step value
    prev_timestep = (
        timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    )

    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_timestep]
        if prev_timestep >= 0
        else scheduler.final_alpha_cumprod
    )
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted x_0 and predicted noise
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range,
            scheduler.config.clip_sample_range,
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = scheduler._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
        0.5
    ) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )

    return prev_sample


def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Compute the previous sample with an Euler step toward ``sigma_down``.

    Advances ``scheduler._step_index`` by one as a side effect.

    Args:
        model_output: Output of the diffusion model.
        timestep: Current timestep; must not be an ``int`` or an integer
            tensor.  Used to initialise the step index when it is unset.
        sample: Current sample; upcast to float32 for the computation.
        eta, use_clipped_model_output, generator, variance_noise,
            return_dict: Accepted for signature compatibility; not used.
        scheduler: Scheduler providing ``sigmas``, ``step_index``,
            ``_init_step_index`` and ``config.prediction_type``.

    Returns:
        The previous-sample tensor, cast to ``model_output.dtype`` (a bare
        tensor, not an output object).

    Raises:
        ValueError: For integer timesteps or an unknown prediction type.
        NotImplementedError: For ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    sigma = scheduler.sigmas[scheduler.step_index]

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif scheduler.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif scheduler.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_from = scheduler.sigmas[scheduler.step_index]
    sigma_to = scheduler.sigmas[scheduler.step_index + 1]
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma

    # Step only as far as sigma_down; no noise is added back afterwards.
    dt = sigma_down - sigma

    prev_sample = sample + derivative * dt

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    scheduler._step_index += 1

    return prev_sample


def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Compute the previous sample with a plain Euler step and no added noise.

    Advances ``scheduler._step_index`` by one as a side effect and logs a
    warning when ``scheduler.is_scale_input_called`` is false.

    Args:
        model_output: Output of the diffusion model.
        timestep: Current timestep; must not be an ``int`` or an integer
            tensor.  Used to initialise the step index when it is unset.
        sample: Current sample; upcast to float32 for the computation.
        s_churn, s_tmin, s_tmax: Determine ``gamma``, which inflates the
            current sigma to ``sigma_hat`` when sigma lies in
            ``[s_tmin, s_tmax]``.
        eta, use_clipped_model_output, s_noise, generator, variance_noise,
            return_dict: Accepted for signature compatibility; not used.
        scheduler: Scheduler providing ``sigmas``, ``step_index``,
            ``_init_step_index``, ``is_scale_input_called`` and
            ``config.prediction_type``.  Required despite the ``None``
            default.

    Returns:
        The previous-sample tensor, cast to ``model_output.dtype`` (a bare
        tensor, not an output object).

    Raises:
        ValueError: For integer timesteps or an unknown prediction type.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    sigma = scheduler.sigmas[scheduler.step_index]

    gamma = (
        min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
        if s_tmin <= sigma <= s_tmax
        else 0.0
    )

    sigma_hat = sigma * (gamma + 1)

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    # NOTE: "original_sample" should not be an expected prediction_type but is left in for
    # backwards compatibility
    if scheduler.config.prediction_type in ("original_sample", "sample"):
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma_hat * model_output
    elif scheduler.config.prediction_type == "v_prediction":
        # denoised = model_output * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma_hat

    dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat

    prev_sample = sample + derivative * dt

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    scheduler._step_index += 1

    return prev_sample


def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Compute the DDPM posterior mean for the previous sample (no noise added).

    Args:
        model_output: Output of the diffusion model.  When it has twice the
            channels of ``sample`` and the scheduler's variance type is
            ``"learned"`` or ``"learned_range"``, only the first half is used.
        timestep: Current timestep; indexes ``scheduler.alphas_cumprod``.
        sample: Current sample.
        eta, use_clipped_model_output, generator, variance_noise,
            return_dict: Accepted for signature compatibility; not used.
        scheduler: Scheduler providing ``previous_timestep``,
            ``alphas_cumprod``, ``one``, ``variance_type``, ``config`` and
            ``_threshold_sample``.

    Returns:
        The predicted previous-sample mean (a bare tensor, not an output
        object).

    Raises:
        ValueError: If the prediction type is unsupported.
    """
    t = timestep

    prev_t = scheduler.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
        "learned",
        "learned_range",
    ]:
        # The predicted-variance half is discarded: no noise is added here.
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction`  for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (
        alpha_prod_t_prev ** (0.5) * current_beta_t
    ) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = (
        pred_original_sample_coeff * pred_original_sample
        + current_sample_coeff * sample
    )

    return pred_prev_sample