# This code uses diffusers library from https://github.com/huggingface/diffusers import imp import numpy as np import torch import random from PIL import Image, ImageDraw, ImageFont import copy from typing import Optional, Union, Tuple, List, Callable, Dict, Any from tqdm.notebook import tqdm from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UpBlock2D, get_down_block, get_up_block, ) from diffusers.models.unet_2d_condition import UNet2DConditionOutput from copy import deepcopy import json import inspect import os import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils.torch_utils import is_compiled_module from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from tqdm import tqdm import time from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( UNetMidBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, get_up_block, ) # from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor # from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from typing import Any, Dict, List from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import CONFIG_NAME class PipelineCallback(ConfigMixin): config_name = CONFIG_NAME @register_to_config def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): super().__init__() if (cutoff_step_ratio is None and cutoff_step_index is None) or ( cutoff_step_ratio is not None and cutoff_step_index is not None ): raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") if cutoff_step_ratio is not None and ( not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) ): raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") @property def tensor_inputs(self) -> List[str]: raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) class MultiPipelineCallbacks: def __init__(self, callbacks: List[PipelineCallback]): self.callbacks = callbacks @property def tensor_inputs(self) -> List[str]: return [input for callback in self.callbacks for input in callback.tensor_inputs] def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: """ Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. """ for callback in self.callbacks: callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) return callback_kwargs def seed_everything(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) random.seed(seed) np.random.seed(seed) def get_promptls(prompt_path): with open(prompt_path) as f: prompt_ls = json.load(f) prompt_ls = [prompt['caption'].replace('/','_') for prompt in prompt_ls] return prompt_ls def warpped_feature(sample, step): """ sample: batch_size*dim*h*w, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size step: timestep span """ bs, dim, h, w = sample.shape uncond_fea, cond_fea = sample.chunk(2) uncond_fea = uncond_fea.repeat(step,1,1,1) # (step * bs//2) * dim * h *w cond_fea = cond_fea.repeat(step,1,1,1) # (step * bs//2) * dim * h *w return torch.cat([uncond_fea, cond_fea]) def warpped_skip_feature(block_samples, step): down_block_res_samples = [] for sample in block_samples: sample_expand = warpped_feature(sample, step) down_block_res_samples.append(sample_expand) return tuple(down_block_res_samples) def warpped_text_emb(text_emb, step): """ text_emb: batch_size*77*768, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size step: timestep span """ bs, token_len, dim = text_emb.shape uncond_fea, cond_fea = text_emb.chunk(2) uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768 cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768 return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768 def warpped_timestep(timesteps, bs): """ timestpes: list, such as [981, 961, 941] """ semi_bs = bs//2 ts = [] for timestep in timesteps: timestep = timestep[None] texp = timestep.expand(semi_bs) ts.append(texp) timesteps = torch.cat(ts) return timesteps.repeat(2,1).reshape(-1) def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg def register_parallel_pipeline_orig(pipe, mod='50ls4'): def new_call(self): @torch.no_grad() def call( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt (mostly unchanged from original) lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. Prepare timesteps (modified to handle parallel processing) self.scheduler.set_timesteps(num_inference_steps, device=device) #Simplified timestep setting timesteps = self.scheduler.timesteps # 5. Prepare latent variables (unchanged) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs (unchanged) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.5 Optionally get Guidance Scale Embedding (unchanged) timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Denoising loop (modified for parallel processing) init_latents = latents.detach().clone() all_steps = len(self.scheduler.timesteps) curr_step = 0 if mod == '50ls': cond = lambda timestep: timestep in [0, 1, 2, 3, 5, 10, 15, 25, 35] #Example condition, adjust as needed elif isinstance(mod, int): cond = lambda timestep: timestep % mod == 0 else: raise Exception("Currently not supported, But you can modify the code to customize the keytime") #added time_ids and add_text_embeds for original pipeline compatibility add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=pooled_prompt_embeds.shape[-1] if self.text_encoder_2 is None else self.text_encoder_2.config.projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=pooled_prompt_embeds.shape[-1] if self.text_encoder_2 is None else self.text_encoder_2.config.projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) add_text_embeds = pooled_prompt_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) while curr_step < all_steps: register_time(self.unet, curr_step) # Assuming register_time is defined elsewhere time_ls = [timesteps[curr_step]] curr_step += 1 while not cond(curr_step) and curr_step < all_steps: time_ls.append(timesteps[curr_step]) curr_step += 1 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents #added kwargs for original pipeline compatibility added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, time_ls, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) bs = noise_pred.shape[0] bs_perstep = bs // len(time_ls) denoised_latent = latents for i, timestep in enumerate(time_ls): if timestep / 1000 < 0.5: denoised_latent = denoised_latent + 0.003 * init_latents curr_noise = noise_pred[i * bs_perstep : (i + 1) * bs_perstep] denoised_latent = self.scheduler.step( curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False )[0] latents = denoised_latent #rest of the code remains largely the same as the original pipeline if not output_type == "latent": needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): self.vae = self.vae.to(latents.dtype) has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) # Assuming this output type exists return call pipe.call = new_call(pipe) def register_faster_forward(model, mod = '50ls4'): def faster_forward(self): def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, # ADDED down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, # ADDED return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): print("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time if isinstance(timestep, list): timesteps = timestep[0] step = len(timestep) else: timesteps = timestep step = 1 if not torch.is_tensor(timesteps) and (not isinstance(timesteps,list)): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif (not isinstance(timesteps,list)) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) if (not isinstance(timesteps,list)) and len(timesteps.shape) == 1: # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) elif isinstance(timesteps, list): #timesteps list, such as [981,961,941] timesteps = warpped_timestep(timesteps, sample.shape[0]).to(sample.device) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) # class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) # if class_emb is not None: # if self.config.class_embeddings_concat: # emb = torch.cat([emb, class_emb], dim=-1) # else: # emb = emb + class_emb # aug_emb = self.get_aug_embed( # emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs # ) # encoder_hidden_states = self.process_encoder_hidden_states( # encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs # ) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) emb = emb + aug_emb # emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None: print("Encoder hid projections are not none") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) #=============== order = self.order #timestep, start by 0 #=============== ipow = int(np.sqrt(9 + 8*order)) cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] if isinstance(mod, int): cond = order % mod == 0 elif mod == "pro": cond = ipow * ipow == (9 + 8 * order) elif mod == "50ls": cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls2": cond = order in [0, 10, 11, 12, 15, 20, 25, 30,35,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls3": cond = order in [0, 20, 25, 30,35,45,46,47,48,49] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls4": cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "100ls": cond = order > 85 or order < 10 or order % 5 == 0 elif mod == "75ls": cond = order > 65 or order < 10 or order % 5 == 0 elif mod == "s2": cond = order < 20 or order > 40 or order % 2 == 0 if cond: # print('current timestep:', order) # 2. pre-process sample = self.conv_in(sample) # 3. down down_block_res_samples = (sample,) print(f"Shape of sample at the start of downsample: {sample.shape}") for i, downsample_block in enumerate(self.down_blocks): print(f"\n--- Downsample Block {i+1} ---") print(f" Input shape: {sample.shape}") if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: print("has cross attention") additional_residuals = {} if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) print(f" Additional Residuals shape: {additional_residuals['additional_residuals'].shape}") sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, **additional_residuals ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) # if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0: # sample += down_intrablock_additional_residuals.pop(0) print(f" Output shape: {sample.shape}") for j, res_sample in enumerate(res_samples): print(f" Res sample {j+1} shape: {res_sample.shape}") down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # Handle ControlNet additional residuals if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, ) else: sample = self.mid_block(sample, emb) #Handle T2I-Adapter-XL if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape: sample += down_intrablock_additional_residuals.pop(0) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual #----------------------save feature------------------------- # setattr(self, 'skip_feature', (tmp_sample.clone() for tmp_sample in down_block_res_samples)) setattr(self, 'skip_feature', deepcopy(down_block_res_samples)) setattr(self, 'toup_feature', sample.detach().clone()) #-----------------------save feature------------------------ #-------------------expand feature for parallel--------------- if isinstance(timestep, list): #timesteps list, such as [981,961,941] timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) sample = warpped_feature(sample, step) encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) #-------------------expand feature for parallel--------------- else: down_block_res_samples = self.skip_feature sample = self.toup_feature #-------------------expand feature for parallel--------------- down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) sample = warpped_feature(sample, step) encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) #-------------------expand feature for parallel--------------- # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) return forward if model.__class__.__name__ == 'UNet2DConditionModel': model.forward = faster_forward(model) def register_faster_orig_forward(model, mod = '50ls4'): def faster_orig_forward(self): def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) #=============== order = model.order #timestep, start by 0 #=============== ipow = int(np.sqrt(9 + 8*order)) cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] if isinstance(mod, int): cond = order % mod == 0 elif mod == "pro": cond = ipow * ipow == (9 + 8 * order) elif mod == "50ls": cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls2": cond = order in [0, 10, 11, 12, 15, 20, 25, 30,35,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls3": cond = order in [0, 20, 25, 30,35,45,46,47,48,49] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "50ls4": cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] elif mod == "100ls": cond = order > 85 or order < 10 or order % 5 == 0 elif mod == "75ls": cond = order > 65 or order < 10 or order % 5 == 0 elif mod == "s2": cond = order < 20 or order > 40 or order % 2 == 0 if cond: # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual #----------------------save feature------------------------- # setattr(self, 'skip_feature', (tmp_sample.clone() for tmp_sample in down_block_res_samples)) setattr(self, 'skip_feature', deepcopy(down_block_res_samples)) setattr(self, 'toup_feature', sample.detach().clone()) #-----------------------save feature------------------------ #-------------------expand feature for parallel--------------- if isinstance(timestep, list): step = len(timestep) #timesteps list, such as [981,961,941] timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) else: step = 1 down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) sample = warpped_feature(sample, step) encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) #-------------------expand feature for parallel--------------- else: step = 1 down_block_res_samples = self.skip_feature sample = self.toup_feature #-------------------expand feature for parallel--------------- down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) sample = warpped_feature(sample, step) encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) #-------------------expand feature for parallel--------------- # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=lora_scale, ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) return forward if model.__class__.__name__ == 'UNet2DConditionModel': model.forward = faster_orig_forward(model) def register_time(unet, t): setattr(unet, 'order', t)