|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
import inspect |
|
import math |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from packaging import version |
|
from PIL import Image |
|
from torch import nn |
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast |
|
|
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
|
from diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config |
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
|
from diffusers.loaders import ( |
|
FromSingleFileMixin, |
|
IPAdapterMixin, |
|
PeftAdapterMixin, |
|
StableDiffusionLoraLoaderMixin, |
|
TextualInversionLoaderMixin, |
|
UNet2DConditionLoadersMixin, |
|
) |
|
from diffusers.loaders.single_file_model import FromOriginalModelMixin |
|
from diffusers.models.activations import GELU, get_activation |
|
from diffusers.models.attention_processor import ( |
|
ADDED_KV_ATTENTION_PROCESSORS, |
|
CROSS_ATTENTION_PROCESSORS, |
|
Attention, |
|
AttentionProcessor, |
|
AttnAddedKVProcessor, |
|
AttnProcessor, |
|
FusedAttnProcessor2_0, |
|
) |
|
from diffusers.models.downsampling import Downsample2D |
|
from diffusers.models.embeddings import ( |
|
GaussianFourierProjection, |
|
GLIGENTextBoundingboxProjection, |
|
ImageHintTimeEmbedding, |
|
ImageProjection, |
|
ImageTimeEmbedding, |
|
TextImageProjection, |
|
TextImageTimeEmbedding, |
|
TextTimeEmbedding, |
|
TimestepEmbedding, |
|
Timesteps, |
|
) |
|
from diffusers.models.lora import adjust_lora_scale_text_encoder |
|
from diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin |
|
from diffusers.models.resnet import ResnetBlock2D |
|
from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D |
|
from diffusers.models.upsampling import Upsample2D |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin |
|
from diffusers.utils import ( |
|
USE_PEFT_BACKEND, |
|
BaseOutput, |
|
deprecate, |
|
is_torch_version, |
|
is_torch_xla_available, |
|
logging, |
|
replace_example_docstring, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
from diffusers.utils.torch_utils import apply_freeu, randn_tensor |
|
|
|
|
|
# Record whether torch_xla can be used; the XLA model helpers are imported only
# when `is_torch_xla_available()` reports true.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True

# Module-level logger obtained through the diffusers logging utilities.
logger = logging.get_logger(__name__)
|
|
|
# Usage example for the pipeline, written in doctest style.
# NOTE(review): presumably injected into a docstring through `replace_example_docstring`
# (imported above) - the call site is outside this part of the file.
# The second line of the `from_pretrained(...)` call uses the secondary prompt `...`:
# a continuation line that starts with `>>>` would be read as a new statement.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import DiffusionPipeline
        >>> from diffusers.utils import make_image_grid

        >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
        >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
        ...                                          custom_pipeline="matryoshka").to("cuda")

        >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
        >>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
        >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
        >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
        >>> make_image_grid(image, rows=1, cols=len(image))

        >>> pipe.change_nesting_level(<int>) # 0, 1, or 2
        >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
        ```
"""
|
|
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is first scaled so that its per-sample standard deviation (taken over every dimension
    except the first) equals that of `noise_pred_text`; the scaled and unscaled predictions are then blended with
    weights `guidance_rescale` and `1 - guidance_rescale`.

    Args:
        noise_cfg (`torch.Tensor`): The guided noise prediction.
        noise_pred_text (`torch.Tensor`): The text-conditioned noise prediction.
        guidance_rescale (`float`, defaults to 0.0): Blend weight given to the rescaled prediction.

    Returns:
        `torch.Tensor`: The blended prediction.
    """

    def _per_sample_std(tensor):
        # Reduce over everything but dim 0 and keep the dims so the result broadcasts back.
        reduce_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_dims, keepdim=True)

    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * std_ratio
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler through its `set_timesteps` method and hand back the resulting schedule. Supports a
    custom timestep list or a custom sigma list as an alternative to a plain step count. Extra keyword arguments
    are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            Passed to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` has no
            parameter for the custom schedule that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain case: no custom schedule, the step count is returned unchanged.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule is only usable when `set_timesteps` exposes the matching parameter.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
|
|
|
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): |
|
|
|
if hidden_states.shape[chunk_dim] % chunk_size != 0: |
|
raise ValueError( |
|
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." |
|
) |
|
|
|
num_chunks = hidden_states.shape[chunk_dim] // chunk_size |
|
ff_output = torch.cat( |
|
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], |
|
dim=chunk_dim, |
|
) |
|
return ff_output |
|
|
|
|
|
@dataclass
class MatryoshkaDDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Both fields hold either a single tensor or a list of tensors; `MatryoshkaDDIMScheduler.step` returns lists when
    it processes several model outputs at once.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, or a list of
            such tensors):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, or a
            list of such tensors, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Required: sample(s) for the previous timestep.
    prev_sample: Union[torch.Tensor, List[torch.Tensor]]
    # Optional: predicted fully denoised sample(s); defaults to None.
    pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
|
|
|
|
|
|
|
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t`. For step `i` the beta is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): 1-D `float32` tensor with `num_diffusion_timesteps` entries.

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # Each beta compares alpha_bar at the end of the step with alpha_bar at its start.
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
|
|
|
|
|
|
|
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    The square root of the cumulative alpha product is shifted so that its last entry becomes zero, rescaled so
    that its first entry keeps its original value, and then converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt(alpha_bar)
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember both end points before modifying the curve.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift the last entry to zero, then scale the first entry back to its old value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # sqrt(alpha_bar) -> alpha_bar -> per-step alphas (undo the cumulative product) -> betas
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas
|
|
|
|
|
class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This variant additionally accepts a list of model outputs / samples in `step` and, with
    `timestep_spacing="matryoshka_style"`, shifts the noise schedule per entry using the factors in `self.scales`.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. `set_timesteps`
            handles `linspace`, `leading`, `trailing` and `matryoshka_style`.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        """Build the beta / alpha tables and the default timestep list covering every training step."""
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            if self.config.timestep_spacing == "matryoshka_style":
                # A leading zero beta gives the table `num_train_timesteps + 1` entries.
                self.betas = torch.cat((torch.tensor([0]), betas_for_alpha_bar(num_train_timesteps)))
            else:
                self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Cumulative alpha used when a step has no previous timestep (`prev_timestep < 0`):
        # either exactly 1 or the value at step 0, depending on `set_alpha_to_one`.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        self.init_noise_sigma = 1.0

        # Values that `set_timesteps` overwrites.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        # Inputs of `get_schedule_shifted`; `scales` is read by `step` for nested outputs and is
        # expected to be assigned from outside this class.
        self.scales = None
        self.schedule_shifted_power = 1.0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                The input sample, returned unchanged by this scheduler.
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Return `(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)`, where `a_t` and `a_prev` are the cumulative alpha
        products at `timestep` and `prev_timestep` (`final_alpha_cumprod` when `prev_timestep` is negative).
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            # Work in float32; the input dtype is restored before returning.
            sample = sample.float()

        # Flatten everything after the batch dimension so the quantile is taken per batch entry.
        # `np.prod([])` is the float 1.0, so cast to int to keep the reshape valid when there are
        # no dimensions after `channels`.
        sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

        abs_sample = sample.abs()

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        # With the lower bound of 1, a quantile below 1 reduces to plain clipping to [-1, 1].
        s = torch.clamp(s, min=1, max=self.config.sample_max_value)
        s = s.unsqueeze(1)  # (batch_size, 1) so it broadcasts against the flattened sample
        sample = torch.clamp(sample, -s, s) / s

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                Device the resulting `self.timesteps` tensor is moved to.

        Raises:
            ValueError: If `num_inference_steps` exceeds `self.config.num_train_timesteps`, or if
                `self.config.timestep_spacing` is not one of the handled values.
        """
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                f" with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # Integer multiples of the ratio, in descending order, shifted by `steps_offset`.
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # Count down from `num_train_timesteps` in (possibly fractional) strides, then make it 0-based.
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        elif self.config.timestep_spacing == "matryoshka_style":
            # Produces `num_inference_steps + 1` entries; `step` looks up the following entry of this
            # list to find the previous timestep.
            step_ratio = (self.config.num_train_timesteps + 1) / (num_inference_steps + 1)
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64)
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading', 'trailing' or 'matryoshka_style'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def get_schedule_shifted(self, alpha_prod, scale_factor=None):
        """
        Shift a cumulative alpha product by a scale factor.

        When `scale_factor` is given and greater than 1, it is raised to `self.schedule_shifted_power`, the
        signal-to-noise ratio `alpha_prod / (1 - alpha_prod)` is divided by it, and the result is converted back to
        a cumulative alpha product. Otherwise `alpha_prod` is returned unchanged.
        """
        if (scale_factor is not None) and (scale_factor > 1):
            scale_factor = scale_factor**self.schedule_shifted_power
            snr = alpha_prod / (1 - alpha_prod)
            scaled_snr = snr / scale_factor
            alpha_prod = 1 / (1 + 1 / scaled_snr)
        return alpha_prod

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor` or list of `torch.Tensor`):
                The direct output from learned diffusion model. When `len(model_output) > 1` the per-entry code
                paths are used and `sample` is iterated alongside it.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor` or list of `torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`MatryoshkaDDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`MatryoshkaDDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`MatryoshkaDDIMSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the previous sample (tensor or list of tensors).

        Raises:
            ValueError: If `set_timesteps` has not been called, if `prediction_type` is unknown, or if both
                `generator` and `variance_noise` are passed while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # The length check is how this method tells several per-entry outputs apart from a single one.
        # NOTE(review): a single tensor whose first dimension is larger than 1 also satisfies it.
        nested = len(model_output) > 1

        # 1. previous timestep
        if self.config.timestep_spacing != "matryoshka_style":
            prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
        else:
            # The entry that follows `timestep` in the list built by `set_timesteps`.
            prev_timestep = self.timesteps[torch.nonzero(self.timesteps == timestep).item() + 1]

        # 2. alphas and betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        if self.config.timestep_spacing == "matryoshka_style" and nested:
            # One shifted value per entry of `self.scales`.
            alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in self.scales])
            alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in self.scales])

        beta_prod_t = 1 - alpha_prod_t

        # 3. predicted original sample and predicted noise
        #    (only the `v_prediction` branch has a per-entry path for nested outputs)
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            if nested:
                pred_original_sample = []
                pred_epsilon = []
                for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t):
                    pred_original_sample.append((a_p_t**0.5) * s - (b_p_t**0.5) * m_o)
                    pred_epsilon.append((a_p_t**0.5) * m_o + (b_p_t**0.5) * s)
            else:
                pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
                pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. threshold or clip the predicted original sample
        if self.config.thresholding:
            if nested:
                pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
            else:
                pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            if nested:
                pred_original_sample = [
                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
                    for p_o_s in pred_original_sample
                ]
            else:
                pred_original_sample = pred_original_sample.clamp(
                    -self.config.clip_sample_range, self.config.clip_sample_range
                )

        # 5. variance and noise standard deviation (computed from the unshifted schedule)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # Re-derive the noise estimate from the (possibly clipped) predicted original sample.
            if nested:
                pred_epsilon = [
                    (s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5)
                    for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t)
                ]
            else:
                pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. direction term built from the predicted noise
        if nested:
            pred_sample_direction = [
                (1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev)
            ]
        else:
            pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. previous sample without the random-noise term
        if nested:
            prev_sample = [
                a_p_t_p ** (0.5) * p_o_s + p_s_d
                for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev)
            ]
        else:
            prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # 8. add noise scaled by `std_dev_t` when eta > 0
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                if nested:
                    variance_noise = [
                        randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)
                        for m_o in model_output
                    ]
                else:
                    variance_noise = randn_tensor(
                        model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                    )
            if nested:
                prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)]
            else:
                variance = std_dev_t * variance_noise

                prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def _sqrt_alpha_terms(self, reference: torch.Tensor, timesteps: torch.IntTensor):
        """
        Shared helper of `add_noise` and `get_velocity`: return `sqrt(alpha_cumprod[t])` and
        `sqrt(1 - alpha_cumprod[t])` for `timesteps`, flattened and then padded with trailing singleton dimensions
        until they have as many dimensions as `reference`.
        """
        # The stored table is moved to the reference's device and kept there; the dtype cast is local.
        self.alphas_cumprod = self.alphas_cumprod.to(device=reference.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=reference.dtype)
        timesteps = timesteps.to(reference.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(reference.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(reference.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        return sqrt_alpha_prod, sqrt_one_minus_alpha_prod

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Return `sqrt(alpha_cumprod[t]) * original_samples + sqrt(1 - alpha_cumprod[t]) * noise` for the given
        `timesteps`. As a side effect `self.alphas_cumprod` is moved to the device of `original_samples`.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._sqrt_alpha_terms(original_samples, timesteps)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """
        Return `sqrt(alpha_cumprod[t]) * noise - sqrt(1 - alpha_cumprod[t]) * sample` for the given `timesteps`.
        As a side effect `self.alphas_cumprod` is moved to the device of `sample`.
        """
        sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._sqrt_alpha_terms(sample, timesteps)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
|
|
|
|
class CrossAttnDownBlock2D(nn.Module): |
|
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    dropout: float = 0.0,
    num_layers: int = 1,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    resnet_eps: float = 1e-6,
    resnet_time_scale_shift: str = "default",
    resnet_act_fn: str = "swish",
    resnet_groups: int = 32,
    resnet_pre_norm: bool = True,
    norm_type: str = "layer_norm",
    num_attention_heads: int = 1,
    cross_attention_dim: int = 1280,
    cross_attention_norm: Optional[str] = None,
    output_scale_factor: float = 1.0,
    downsample_padding: int = 1,
    add_downsample: bool = True,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    attention_type: str = "default",
    attention_pre_only: bool = False,
    attention_bias: bool = False,
    use_attention_ffn: bool = True,
):
    """
    Build `num_layers` pairs of (`ResnetBlock2D`, `MatryoshkaTransformer2DModel`) and, when `add_downsample` is
    true, a single `Downsample2D`.

    The first resnet maps `in_channels` to `out_channels`; every later one maps `out_channels` to `out_channels`.
    `transformer_layers_per_block` may be a single int, which is then used for every layer. Several parameters
    (`norm_type`, `cross_attention_norm`, `dual_cross_attention`, `use_linear_projection`, `only_cross_attention`,
    `attention_type`, `attention_pre_only`, `attention_bias`) are accepted but not read in this constructor.
    """
    super().__init__()

    self.has_cross_attention = True
    self.num_attention_heads = num_attention_heads

    # Normalise the per-layer transformer depth to one entry per layer.
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers

    resnet_blocks = []
    attention_blocks = []
    for layer_idx in range(num_layers):
        # Only the first resnet sees the block's input width.
        layer_in_channels = in_channels if layer_idx == 0 else out_channels
        resnet_blocks.append(
            ResnetBlock2D(
                in_channels=layer_in_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        )
        attention_blocks.append(
            MatryoshkaTransformer2DModel(
                num_attention_heads,
                out_channels // num_attention_heads,
                in_channels=out_channels,
                num_layers=transformer_layers_per_block[layer_idx],
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_attention_ffn=use_attention_ffn,
            )
        )

    # Registration order (attentions, then resnets, then downsamplers) is kept as in the original.
    self.attentions = nn.ModuleList(attention_blocks)
    self.resnets = nn.ModuleList(resnet_blocks)

    self.downsamplers = (
        nn.ModuleList(
            [
                Downsample2D(
                    out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                )
            ]
        )
        if add_downsample
        else None
    )

    self.gradient_checkpointing = False
|
|
|
def forward(
    self,
    hidden_states: torch.Tensor,
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    additional_residuals: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """
    Run every (resnet, attention) pair of the block, then the optional downsamplers.

    Args:
        hidden_states: Input tensor handed to the first resnet.
        temb: Passed as the second positional argument to every resnet.
        encoder_hidden_states: Forwarded unchanged to every attention module.
        attention_mask: Forwarded unchanged to every attention module.
        cross_attention_kwargs: Forwarded unchanged to every attention module. A non-`None` `"scale"` entry only
            triggers a warning.
        encoder_attention_mask: Forwarded unchanged to every attention module.
        additional_residuals: When given, added to the output of the last (resnet, attention) pair.

    Returns:
        `(hidden_states, output_states)`; `output_states` collects the output of each pair and, when downsamplers
        exist, the downsampled tensor as its last entry.
    """
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    def _positional_only(module):
        # torch.utils.checkpoint expects a plain callable taking positional inputs.
        def _invoke(*inputs):
            return module(*inputs)

        return _invoke

    collected = []
    pairs = list(zip(self.resnets, self.attentions))
    final_index = len(pairs) - 1

    for index, (resnet, attn) in enumerate(pairs):
        # Only the resnet is wrapped in activation checkpointing; attention always runs directly.
        if self.training and self.gradient_checkpointing:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                _positional_only(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = resnet(hidden_states, temb)

        attn_result = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )
        hidden_states = attn_result[0]

        # The extra residual is applied to the output of the last pair only.
        if index == final_index and additional_residuals is not None:
            hidden_states = hidden_states + additional_residuals

        collected.append(hidden_states)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        collected.append(hidden_states)

    return hidden_states, tuple(collected)
|
|
|
|
|
class UNetMidBlock2DCrossAttn(nn.Module):
    """
    Middle block: one leading `ResnetBlock2D`, then `num_layers` repetitions of
    (`MatryoshkaTransformer2DModel`, `ResnetBlock2D`).

    Only the arguments listed below are read by the constructor; the remaining keyword arguments
    (`norm_type`, `cross_attention_norm`, `dual_cross_attention`, `use_linear_projection`, `attention_type`,
    `attention_pre_only`, `attention_bias`) are accepted but not used here.

    Args:
        in_channels: Input channels of the leading resnet.
        temb_channels: `temb_channels` of every resnet.
        out_channels: Output channels; falls back to `in_channels` when falsy.
        dropout: Dropout of every resnet.
        num_layers: Number of (attention, resnet) repetitions after the leading resnet.
        transformer_layers_per_block: Transformer depth per repetition; an `int` is repeated `num_layers` times.
        resnet_eps / resnet_time_scale_shift / resnet_act_fn / resnet_pre_norm / output_scale_factor:
            Forwarded to every resnet.
        resnet_groups: Group count of the leading resnet; `min(in_channels // 4, 32)` when `None`.
        resnet_groups_out: `groups_out` of the leading resnet and `groups` of the later ones; falls back to
            `resnet_groups` when falsy.
        num_attention_heads: Head count; the per-head width is `out_channels // num_attention_heads`.
        cross_attention_dim / upcast_attention / use_attention_ffn: Forwarded to every transformer.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        if not out_channels:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        if not resnet_groups_out:
            resnet_groups_out = resnet_groups

        # Settings shared by the leading resnet and every later one.
        shared_resnet_kwargs = {
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "eps": resnet_eps,
            "dropout": dropout,
            "time_embedding_norm": resnet_time_scale_shift,
            "non_linearity": resnet_act_fn,
            "output_scale_factor": output_scale_factor,
            "pre_norm": resnet_pre_norm,
        }

        # There is always at least one resnet.
        resnet_stack = [
            ResnetBlock2D(
                in_channels=in_channels,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                **shared_resnet_kwargs,
            )
        ]
        attention_stack = []

        for layer_index in range(num_layers):
            attention_stack.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_index],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )
            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=out_channels,
                    groups=resnet_groups_out,
                    **shared_resnet_kwargs,
                )
            )

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the leading resnet, then each (attention, resnet) pair in order.

        `temb` goes to every resnet; `encoder_hidden_states`, `attention_mask`, `cross_attention_kwargs` and
        `encoder_attention_mask` are forwarded unchanged to every attention module. A non-`None` `"scale"` entry
        in `cross_attention_kwargs` only triggers a warning.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def _positional_only(module):
            # torch.utils.checkpoint expects a plain callable taking positional inputs.
            def _invoke(*inputs):
                return module(*inputs)

            return _invoke

        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            attn_result = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = attn_result[0]

            # Only the resnet is wrapped in activation checkpointing.
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _positional_only(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
|
|
|
|
|
class CrossAttnUpBlock2D(nn.Module):
    """
    Decoder block made of `num_layers` (`ResnetBlock2D`, `MatryoshkaTransformer2DModel`) pairs and an optional
    `Upsample2D`.

    Each resnet consumes the running hidden states concatenated (along dim 1) with one skip tensor, so its input
    width is the sum of both channel counts. The keyword arguments `norm_type`, `cross_attention_norm`,
    `dual_cross_attention`, `use_linear_projection`, `only_cross_attention`, `attention_type`,
    `attention_pre_only` and `attention_bias` are accepted but not used by this constructor.

    Args:
        in_channels: Channel count of the skip tensor consumed by the last pair.
        out_channels: Output channels of every resnet and width of every transformer.
        prev_output_channel: Channel count of the hidden states entering the first pair.
        temb_channels: `temb_channels` of every resnet.
        resolution_idx: Stored on the block and handed to `apply_freeu`.
        dropout / resnet_eps / resnet_time_scale_shift / resnet_act_fn / resnet_groups / resnet_pre_norm /
        output_scale_factor: Forwarded to every resnet.
        num_layers: Number of (resnet, attention) pairs.
        transformer_layers_per_block: Transformer depth per pair; an `int` is repeated `num_layers` times.
        num_attention_heads: Head count; the per-head width is `out_channels // num_attention_heads`.
        cross_attention_dim / upcast_attention / use_attention_ffn: Forwarded to every transformer.
        add_upsample: Whether to append a convolutional `Upsample2D`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_stack = []
        attention_stack = []
        last_layer = num_layers - 1

        for layer_index in range(num_layers):
            # The last pair receives the skip tensor that has `in_channels` channels; earlier ones get
            # `out_channels`-wide skips.
            if layer_index == last_layer:
                skip_channels = in_channels
            else:
                skip_channels = out_channels
            # Only the first pair sees the previous block's output width.
            if layer_index == 0:
                incoming_channels = prev_output_channel
            else:
                incoming_channels = out_channels

            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=incoming_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_stack.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_index],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Consume skip tensors from the end of `res_hidden_states_tuple`, one per (resnet, attention) pair, and
        optionally upsample the result.

        Args:
            hidden_states: Running hidden states.
            res_hidden_states_tuple: Skip tensors; the last entry is used first.
            temb: Passed as the second positional argument to every resnet.
            encoder_hidden_states: Forwarded unchanged to every attention module.
            cross_attention_kwargs: Forwarded unchanged to every attention module. A non-`None` `"scale"` entry
                only triggers a warning.
            upsample_size: Passed as the second positional argument to every upsampler.
            attention_mask: Forwarded unchanged to every attention module.
            encoder_attention_mask: Forwarded unchanged to every attention module.

        Returns:
            The hidden states after all pairs and upsamplers.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # FreeU is active only when all four attributes exist on the block and are truthy.
        is_freeu_enabled = all(getattr(self, attr_name, None) for attr_name in ("s1", "s2", "b1", "b2"))

        def _positional_only(module):
            # torch.utils.checkpoint expects a plain callable taking positional inputs.
            def _invoke(*inputs):
                return module(*inputs)

            return _invoke

        remaining_skips = res_hidden_states_tuple
        for resnet, attn in zip(self.resnets, self.attentions):
            # Pop the most recent skip tensor.
            res_hidden_states = remaining_skips[-1]
            remaining_skips = remaining_skips[:-1]

            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # Only the resnet is wrapped in activation checkpointing; attention always runs directly.
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _positional_only(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            attn_result = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = attn_result[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
|
|
|
|
|
@dataclass
class MatryoshkaTransformer2DModelOutput(BaseOutput):
    """
    Return container of [`MatryoshkaTransformer2DModel`].

    Args:
        sample (`torch.Tensor`):
            The hidden states produced by the last transformer block of the model.
    """

    sample: "torch.Tensor"
|
|
|
|
|
class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
    """
    A stack of `num_layers` [`MatryoshkaTransformerBlock`] modules applied one after another.

    The block width is `num_attention_heads * attention_head_dim`; the `in_channels` argument is recorded in the
    config but the width is not derived from it.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MatryoshkaTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        config = self.config
        self.in_channels = config.num_attention_heads * config.attention_head_dim
        self.gradient_checkpointing = False

        blocks = []
        for _ in range(config.num_layers):
            blocks.append(
                MatryoshkaTransformerBlock(
                    self.in_channels,
                    config.num_attention_heads,
                    config.attention_head_dim,
                    cross_attention_dim=config.cross_attention_dim,
                    upcast_attention=config.upcast_attention,
                    use_attention_ffn=config.use_attention_ffn,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` when it has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`MatryoshkaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor`):
                Input handed to the first transformer block.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Conditioning forwarded to every transformer block.
            timestep (`torch.LongTensor`, *optional*):
                Forwarded to every transformer block.
            added_cond_kwargs (`Dict[str, torch.Tensor]`, *optional*):
                Accepted for interface compatibility; not used by this method.
            class_labels (`torch.LongTensor`, *optional*):
                Forwarded to every transformer block.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Forwarded to every transformer block. A non-`None` `"scale"` entry only triggers a warning.
            attention_mask (`torch.Tensor`, *optional*):
                If 2-dimensional, it is treated as a keep-mask (`1` = keep, `0` = discard) and converted to an
                additive bias `(1 - mask) * -10000.0` with a singleton dimension inserted at index 1. Any other
                rank is forwarded untouched.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Same conversion rule as `attention_mask`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`MatryoshkaTransformer2DModelOutput`] instead of a plain tuple.

        Returns:
            A [`MatryoshkaTransformer2DModelOutput`] if `return_dict` is True, otherwise a one-element `tuple`
            holding the output tensor.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def _keep_mask_to_bias(mask):
            # 2-D masks become an additive bias with a broadcastable singleton query dimension:
            #   [batch, key_tokens] -> [batch, 1, key_tokens]; keep -> 0.0, discard -> -10000.0
            if mask is None or mask.ndim != 2:
                return mask
            bias = (1 - mask.to(hidden_states.dtype)) * -10000.0
            return bias.unsqueeze(1)

        attention_mask = _keep_mask_to_bias(attention_mask)
        encoder_attention_mask = _keep_mask_to_bias(encoder_attention_mask)

        def _positional_only(module):
            # torch.utils.checkpoint expects a plain callable taking positional inputs.
            def _invoke(*inputs):
                return module(*inputs)

            return _invoke

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # Positional order must match the block's forward signature.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _positional_only(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        output = hidden_states

        if return_dict:
            return MatryoshkaTransformer2DModelOutput(sample=output)
        return (output,)
|
|
|
|
|
class MatryoshkaTransformerBlock(nn.Module):
    r"""
    Matryoshka Transformer block: fused self-attention, optional fused cross-attention that reuses the
    self-attention query, an output projection with a residual connection, and an optional feed-forward layer.

    Parameters:
        dim (`int`): Width of the attention query and of the output projection.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Width of each attention head.
        cross_attention_dim (`int`, *optional*):
            Width of `encoder_hidden_states`. Cross-attention is only built when this is a positive number.
        upcast_attention (`bool`, defaults to `False`): Forwarded to the `Attention` modules.
        use_attention_ffn (`bool`, defaults to `True`): Whether to append a `MatryoshkaFeedForward`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.cross_attention_dim = cross_attention_dim

        # 1. Self-attention. The separate q/k/v projections are removed once the fused projection exists.
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            norm_num_groups=32,
            bias=True,
            upcast_attention=upcast_attention,
            pre_only=True,
            processor=MatryoshkaFusedAttnProcessor2_0(),
        )
        self.attn1.fuse_projections()
        del self.attn1.to_q
        del self.attn1.to_k
        del self.attn1.to_v

        # 2. Cross-attention (only when a positive `cross_attention_dim` is configured).
        if cross_attention_dim is not None and cross_attention_dim > 0:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                cross_attention_norm="layer_norm",
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                bias=True,
                upcast_attention=upcast_attention,
                pre_only=True,
                processor=MatryoshkaFusedAttnProcessor2_0(),
            )
            self.attn2.fuse_projections()
            del self.attn2.to_q
            del self.attn2.to_k
            del self.attn2.to_v

        self.proj_out = nn.Linear(dim, dim)

        # 3. Feed-forward.
        if use_attention_ffn:
            self.ff = MatryoshkaFeedForward(dim)
        else:
            self.ff = None

        # Chunking is disabled until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Store the chunk size and chunk dimension used when the feed-forward layer runs."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Apply attention (plus optional feed-forward) with residual connections.

        Args:
            hidden_states: Input of shape `(batch, channels, *spatial_dims)`; the output has the same shape.
            attention_mask: Accepted for interface compatibility; not used by this block.
            encoder_hidden_states: Conditioning passed to the cross-attention module.
            encoder_attention_mask: Passed to the cross-attention module as its `attention_mask`.
            timestep: Accepted for interface compatibility; not used by this block.
            cross_attention_kwargs: Only inspected for a deprecated `"scale"` entry, which triggers a warning.
            class_labels: Accepted for interface compatibility; not used by this block.
            added_cond_kwargs: Accepted for interface compatibility; not used by this block.

        Returns:
            The updated hidden states.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        batch_size, channels, *spatial_dims = hidden_states.shape

        # 1. Self-attention. The processor returns the per-head output together with the per-head query so
        # that cross-attention can reuse the query.
        attn_output, query = self.attn1(hidden_states)

        # 2. Cross-attention.
        if self.cross_attention_dim is not None and self.cross_attention_dim > 0:
            attn_output_cond = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                self_attention_output=attn_output,
                self_attention_query=query,
            )
        else:
            # Without cross-attention nothing was bound to `attn_output_cond` before, so the projection below
            # raised `UnboundLocalError`. Fall back to the self-attention result, merging the heads the same
            # way the processor does for the cross-attention path:
            #   (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
            attn_output_cond = attn_output.transpose(1, 2).reshape(
                batch_size, -1, attn_output.shape[1] * attn_output.shape[-1]
            )

        attn_output_cond = self.proj_out(attn_output_cond)
        attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
        hidden_states = hidden_states + attn_output_cond

        # 3. Feed-forward, optionally chunked to save memory.
        if self.ff is not None:
            if self._chunk_size is not None:
                ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
            else:
                ff_output = self.ff(hidden_states)

            hidden_states = ff_output + hidden_states

        return hidden_states
|
|
|
|
|
class MatryoshkaFusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
    For cross-attention modules, key and value projection matrices are fused.

    The processor has two modes, selected by `self_attention_output`:

    * `self_attention_output is None`: returns the tuple `(per_head_output, per_head_query)`, both shaped
      `(batch, heads, tokens, head_dim)`.
    * `self_attention_output` given: the attention result is added to `self_attention_output`, the heads are
      merged, and a single tensor of shape `(batch, tokens, heads * head_dim)` is returned.

    <Tip warning={true}>

    This API is currently 🧪 experimental in nature and can change in future.

    </Tip>
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        self_attention_query: Optional[torch.Tensor] = None,
        self_attention_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            attn: The attention module providing the projections, norms and head count.
            hidden_states: Query-side input. A 4-D input is flattened to `(batch, tokens, channels)`.
            encoder_hidden_states: Key/value source. When `None`, the fused `to_qkv` projection of
                `hidden_states` is used instead.
            attention_mask: Handed to `F.scaled_dot_product_attention` as `attn_mask` without any reshaping.
            temb: Only used by `attn.spatial_norm` when that module exists.
            self_attention_query: Per-head query to reuse instead of projecting `hidden_states`
                (only read when `encoder_hidden_states` is given).
            self_attention_output: Per-head self-attention result to add to this attention's result.

        Returns:
            See the class docstring for the two return modes.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states)

        # Bind the batch size for every input rank. It used to be set only for 4-D inputs, which made the head
        # reshape below fail with `UnboundLocalError` for already-flattened inputs.
        batch_size = hidden_states.shape[0]

        if input_ndim == 4:
            _, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()

        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // 3
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            if self_attention_query is not None:
                query = self_attention_query
            else:
                query = attn.to_q(hidden_states)

            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // 2
            key, value = torch.split(kv, split_size, dim=-1)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # A query handed over from self-attention is already split into heads; only a freshly projected
        # query needs the reshape.
        if self_attention_output is None:
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # NOTE(review): the q/k norms are applied a second time here, after the head split. Both are no-ops
        # when `attn.norm_q` / `attn.norm_k` are `None`; kept as-is to preserve behavior. Confirm whether the
        # double application is intended.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # The output of sdp = (batch, num_heads, seq_len, head_dim).
        # `attention_mask` is used as given, so it must already be broadcastable to the attention scores.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.to(query.dtype)

        if self_attention_output is not None:
            hidden_states = hidden_states + self_attention_output
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states if self_attention_output is not None else (hidden_states, query)
|
|
|
|
|
class MatryoshkaFeedForward(nn.Module):
    r"""
    A feed-forward layer for the Matryoshka models: group norm over the channel dimension, followed by a
    `dim -> 4 * dim` GELU projection and a `4 * dim -> dim` linear projection applied per position.

    Parameters:
        dim (`int`): Number of input and output channels.
    """

    def __init__(
        self,
        dim: int,
    ):
        super().__init__()

        hidden_dim = dim * 4
        self.group_norm = nn.GroupNorm(32, dim)
        self.linear_gelu = GELU(dim, hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """Apply the layer to `x` of shape `(batch, channels, *spatial_dims)`; the shape is preserved."""
        batch_size, channels, *spatial_dims = x.shape

        normed = self.group_norm(x)
        # Channels-last token layout for the linear layers: (batch, positions, channels).
        tokens = normed.view(batch_size, channels, -1).transpose(1, 2)
        tokens = self.linear_gelu(tokens)
        tokens = self.linear_out(tokens)
        # Back to channels-first with the original spatial layout.
        return tokens.transpose(1, 2).view(batch_size, channels, *spatial_dims)
|
|
|
|
|
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    """
    Build the down block named by `down_block_type`.

    Supported names are `"DownBlock2D"` and `"CrossAttnDownBlock2D"`; a leading `"UNetRes"` prefix is stripped
    before matching. The remaining arguments are forwarded to the selected block's constructor where that
    constructor call uses them; `resnet_skip_time_act`, `resnet_out_scale_factor`, `downsample_type` and
    `attention_head_dim` are not forwarded to either block.

    Raises:
        ValueError: If `"CrossAttnDownBlock2D"` is requested without `cross_attention_dim`, or if
            `down_block_type` is not one of the supported names.
    """
    # If attn head dim is not defined, we default it to the number of heads.
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    if down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )

    # Fail loudly instead of silently handing `None` back to the caller.
    raise ValueError(f"{down_block_type} does not exist.")
|
|
|
|
|
def get_mid_block(
    mid_block_type: str,
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    norm_type: str = "layer_norm",
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """
    Build the mid block named by `mid_block_type`.

    Only `"UNetMidBlock2DCrossAttn"` is recognised; any other value yields `None`.
    `mid_block_only_cross_attention`, `resnet_skip_time_act` and `attention_head_dim` are accepted but not
    forwarded to the block.
    """
    # Guard clause: every unrecognised type results in "no mid block".
    if mid_block_type != "UNetMidBlock2DCrossAttn":
        return None

    block_kwargs = {
        "transformer_layers_per_block": transformer_layers_per_block,
        "in_channels": in_channels,
        "temb_channels": temb_channels,
        "dropout": dropout,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "norm_type": norm_type,
        "output_scale_factor": output_scale_factor,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "cross_attention_dim": cross_attention_dim,
        "cross_attention_norm": cross_attention_norm,
        "num_attention_heads": num_attention_heads,
        "resnet_groups": resnet_groups,
        "dual_cross_attention": dual_cross_attention,
        "use_linear_projection": use_linear_projection,
        "upcast_attention": upcast_attention,
        "attention_type": attention_type,
        "attention_pre_only": attention_pre_only,
    }
    return UNetMidBlock2DCrossAttn(**block_kwargs)
|
|
|
|
|
def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Build one upsampling block of the UNet from its type name.

    Args:
        up_block_type (`str`):
            `"UpBlock2D"` or `"CrossAttnUpBlock2D"`. A leading `"UNetRes"` prefix is stripped before matching.
        attention_head_dim (`int`, *optional*):
            When `None`, a warning is logged and the value falls back to `num_attention_heads`.
        cross_attention_dim (`int`, *optional*):
            Required for `"CrossAttnUpBlock2D"`.
        resnet_skip_time_act, resnet_out_scale_factor, upsample_type:
            Accepted for signature compatibility; not forwarded to either block built here.
        (remaining arguments):
            Forwarded unchanged to the selected block class.

    Returns:
        The constructed `UpBlock2D` or `CrossAttnUpBlock2D`.

    Raises:
        ValueError: If `cross_attention_dim` is missing for `"CrossAttnUpBlock2D"`, or if `up_block_type` is not a
            supported block name.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )

    # Falling off the end used to return None despite the `nn.Module` return annotation.
    raise ValueError(f"{up_block_type} does not exist.")
|
|
|
|
|
class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
    """
    Additional-conditioning embedding: an optional pooled-text projection plus an optional micro-conditioning
    timestep embedding.

    Args:
        addition_time_embed_dim: Channel count passed to `Timesteps` and used as the `TimestepEmbedding` input size.
        cross_attention_dim: Input feature size of the pooled-text projection (used only when `type == "unet"`).
        time_embed_dim: Output size of both the text projection and the micro-conditioning embedding.
        type: `"unet"` creates the bias-free text projection `cond_emb`; `"nested_unet"` disables it.

    Raises:
        ValueError: If `type` is neither `"unet"` nor `"nested_unet"`.
    """

    def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type):
        super().__init__()
        if type == "unet":
            self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False)
        elif type == "nested_unet":
            self.cond_emb = None
        else:
            # Without this, `cond_emb` was never assigned and `forward` failed later with an AttributeError.
            raise ValueError(f"`type` must be 'unet' or 'nested_unet', but is {type!r}.")
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim)

    def forward(self, emb, encoder_hidden_states, added_cond_kwargs):
        """
        Compute the extra conditioning signals.

        Args:
            emb: Tensor whose `device` and `dtype` are used for the micro-conditioning embedding.
            encoder_hidden_states: Text features, pooled over dim 1 (presumably `(batch, seq, dim)` — confirm
                against the caller).
            added_cond_kwargs: Dict read for the optional keys `"conditioning_mask"`, `"masked_cross_attention"`,
                `"from_nested"` and `"micro_conditioning_scale"`.

        Returns:
            A 3-tuple `(temb_micro_conditioning, conditioning_mask, cond_emb)`; each entry may be `None`.
        """
        conditioning_mask = added_cond_kwargs.get("conditioning_mask", None)
        masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False)

        # The pooled-text projection only runs when the layer exists and the call is not flagged `from_nested`.
        if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
            if conditioning_mask is None:
                y = encoder_hidden_states.mean(dim=1)
            else:
                # Mask-weighted mean over dim 1.
                y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum(
                    dim=1, keepdim=True
                )
            cond_emb = self.cond_emb(y)
        else:
            cond_emb = None

        # The mask is only handed back when masked cross-attention is requested.
        if not masked_cross_attention:
            conditioning_mask = None

        micro = added_cond_kwargs.get("micro_conditioning_scale", None)
        if micro is not None:
            temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
            temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))

            return temb_micro_conditioning, conditioning_mask, cond_emb

        return None, conditioning_mask, cond_emb
|
|
|
|
|
@dataclass
class MatryoshkaUNet2DConditionOutput(BaseOutput):
    """
    The output of [`MatryoshkaUNet2DConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        sample_inner (`torch.Tensor`, *optional*):
            Secondary output tensor. NOTE(review): presumably the result(s) of the inner (nested) UNet — confirm
            against the model's `forward`.
    """

    # Both fields default to None, so the annotations are Optional.
    sample: Optional[torch.Tensor] = None
    sample_inner: Optional[torch.Tensor] = None
|
|
|
|
|
class MatryoshkaUNet2DConditionModel(
    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet. `UNetMidBlock2DCrossAttn` is the only block type built by
            `get_mid_block` in this file. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MatryoshkaTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
|
|
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 3,
    out_channels: int = 3,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    dropout: float = 0.0,
    act_fn: str = "silu",
    norm_type: str = "layer_norm",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_attention_ffn: bool = True,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    attention_type: str = "default",
    attention_pre_only: bool = False,
    masked_cross_attention: bool = False,
    micro_conditioning_scale: int = None,
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads: int = 64,
    temporal_mode: bool = False,
    temporal_spatial_ds: bool = False,
    skip_cond_emb: bool = False,
    nesting: Optional[int] = False,
):
    """
    Assemble the UNet: input conv, time/class/additional embeddings, down blocks, mid block, up blocks and the
    output head. See the class docstring for the meaning of the individual arguments.

    Raises:
        ValueError: If `num_attention_heads` is passed, or if `_check_config` or one of the `_set_*` helpers
            rejects the configuration.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # Past the guard above `num_attention_heads` is None, so this resolves to `attention_head_dim`.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Validate the per-block arguments against each other before building anything.
    self._check_config(
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        only_cross_attention=only_cross_attention,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        cross_attention_dim=cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
    )

    # Input convolution.
    self.conv_in = nn.Conv2d(
        in_channels,
        block_out_channels[0],
        kernel_size=conv_in_kernel,
        padding=(conv_in_kernel - 1) // 2,
    )

    # Timestep projection and embedding.
    time_embed_dim, timestep_input_dim = self._set_time_proj(
        time_embedding_type,
        block_out_channels=block_out_channels,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        time_embedding_dim=time_embedding_dim,
    )

    if time_embedding_dim is not None:
        time_embedding_in_dim = time_embedding_dim // 4
    else:
        time_embedding_in_dim = timestep_input_dim
    self.time_embedding = TimestepEmbedding(
        time_embedding_in_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    self._set_encoder_hid_proj(
        encoder_hid_dim_type,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
    )

    # Class embedding.
    self._set_class_embedding(
        class_embed_type,
        act_fn=act_fn,
        num_class_embeds=num_class_embeds,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
        timestep_input_dim=timestep_input_dim,
    )

    # Note: the helper receives `timestep_input_dim`, not the `addition_time_embed_dim` constructor argument.
    self._set_add_embedding(
        addition_embed_type,
        addition_embed_type_num_heads=addition_embed_type_num_heads,
        addition_time_embed_dim=timestep_input_dim,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
    )

    self.time_embed_act = None if time_embedding_act_fn is None else get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar settings to one entry per down block.
    num_blocks = len(down_block_types)

    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * num_blocks

    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * num_blocks

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * num_blocks

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * num_blocks

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * num_blocks

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_blocks

    # With concatenated class embeddings the blocks receive a time embedding twice as wide.
    blocks_time_embed_dim = time_embed_dim * 2 if class_embeddings_concat else time_embed_dim

    # Keyword arguments that are identical for every down and up block.
    shared_block_kwargs = {
        "temb_channels": blocks_time_embed_dim,
        "resnet_eps": norm_eps,
        "resnet_act_fn": act_fn,
        "norm_type": norm_type,
        "resnet_groups": norm_num_groups,
        "dual_cross_attention": dual_cross_attention,
        "use_linear_projection": use_linear_projection,
        "upcast_attention": upcast_attention,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "attention_type": attention_type,
        "attention_pre_only": attention_pre_only,
        "resnet_skip_time_act": resnet_skip_time_act,
        "resnet_out_scale_factor": resnet_out_scale_factor,
        "cross_attention_norm": cross_attention_norm,
        "use_attention_ffn": use_attention_ffn,
        "dropout": dropout,
    }

    last_block_idx = len(block_out_channels) - 1

    # Down path.
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        head_dim = attention_head_dim[i] if attention_head_dim[i] is not None else output_channel

        self.down_blocks.append(
            get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=i != last_block_idx,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                only_cross_attention=only_cross_attention[i],
                attention_head_dim=head_dim,
                **shared_block_kwargs,
            )
        )

    # Mid block.
    self.mid_block = get_mid_block(
        mid_block_type,
        temb_channels=blocks_time_embed_dim,
        in_channels=block_out_channels[-1],
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        norm_type=norm_type,
        resnet_groups=norm_num_groups,
        output_scale_factor=mid_block_scale_factor,
        transformer_layers_per_block=1,
        num_attention_heads=num_attention_heads[-1],
        cross_attention_dim=cross_attention_dim[-1],
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        mid_block_only_cross_attention=mid_block_only_cross_attention,
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_type=attention_type,
        attention_pre_only=attention_pre_only,
        resnet_skip_time_act=resnet_skip_time_act,
        cross_attention_norm=cross_attention_norm,
        attention_head_dim=attention_head_dim[-1],
        dropout=dropout,
    )

    # Number of up blocks that upsample (every one except the final block).
    self.num_upsamplers = 0

    # Up path: per-block settings are consumed in reverse order.
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    if reverse_transformer_layers_per_block is None:
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    else:
        reversed_transformer_layers_per_block = reverse_transformer_layers_per_block
    reversed_only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, last_block_idx)]

        add_upsample = i != last_block_idx
        if add_upsample:
            self.num_upsamplers += 1

        # `attention_head_dim` is indexed un-reversed here, exactly as in the original code.
        head_dim = attention_head_dim[i] if attention_head_dim[i] is not None else output_channel

        self.up_blocks.append(
            get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                add_upsample=add_upsample,
                resolution_idx=i,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                only_cross_attention=reversed_only_cross_attention[i],
                attention_head_dim=head_dim,
                **shared_block_kwargs,
            )
        )

    # Output head: optional norm + activation, then the final convolution.
    if norm_num_groups is None:
        self.conv_norm_out = None
        self.conv_act = None
    else:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )
        self.conv_act = get_activation(act_fn)

    self.conv_out = nn.Conv2d(
        block_out_channels[0],
        out_channels,
        kernel_size=conv_out_kernel,
        padding=(conv_out_kernel - 1) // 2,
    )

    # Receives the per-block (broadcast) `cross_attention_dim`.
    self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)

    self.is_temporal = []
|
|
|
def _check_config(
    self,
    down_block_types: Tuple[str],
    up_block_types: Tuple[str],
    only_cross_attention: Union[bool, Tuple[bool]],
    block_out_channels: Tuple[int],
    layers_per_block: Union[int, Tuple[int]],
    cross_attention_dim: Union[int, Tuple[int]],
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]],
    attention_head_dim: Union[int, Tuple[int]],
    num_attention_heads: Optional[Union[int, Tuple[int]]],
):
    """
    Validate that all per-block configuration values agree in length with `down_block_types`.

    Scalar values (a single `bool`/`int`) are always accepted. Sequence values must have exactly one entry per
    down block. Lists and tuples are treated alike.

    Raises:
        ValueError: If any per-block sequence has the wrong length, or if `transformer_layers_per_block` contains
            nested sequences while `reverse_transformer_layers_per_block` is `None`.
    """
    num_down_blocks = len(down_block_types)

    if num_down_blocks != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != num_down_blocks:
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != num_down_blocks:
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    # `num_attention_heads` is declared Optional; None means "not specified" and has nothing to validate.
    if (
        num_attention_heads is not None
        and not isinstance(num_attention_heads, int)
        and len(num_attention_heads) != num_down_blocks
    ):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != num_down_blocks:
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    # Tuples are accepted by the signature, so they must be length-checked just like lists.
    if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) != num_down_blocks:
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != num_down_blocks:
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # Nested per-layer counts describe an asymmetrical UNet, which needs an explicit layout for the up path.
    if isinstance(transformer_layers_per_block, (list, tuple)) and reverse_transformer_layers_per_block is None:
        if any(isinstance(per_block, (list, tuple)) for per_block in transformer_layers_per_block):
            raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
|
|
def _set_time_proj(
    self,
    time_embedding_type: str,
    block_out_channels: Tuple[int],
    flip_sin_to_cos: bool,
    freq_shift: float,
    time_embedding_dim: Optional[int],
) -> Tuple[int, int]:
    """
    Create `self.time_proj` and derive the time-embedding dimensions.

    Args:
        time_embedding_type: `"fourier"` or `"positional"`.
        block_out_channels: Per-block channel counts; only the first entry is used.
        flip_sin_to_cos: Forwarded to the projection layer.
        freq_shift: Forwarded to `Timesteps` (positional only).
        time_embedding_dim: Optional override of the projected time-embedding size.

    Returns:
        `(time_embed_dim, timestep_input_dim)`.

    Raises:
        ValueError: If `time_embedding_type` is unknown, if the fourier `time_embed_dim` is odd, or if the
            positional branch is used with a `model_type` / `micro_conditioning_scale` combination for which no
            projection is defined.
    """
    base_channels = block_out_channels[0]

    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or base_channels * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        self.time_proj = GaussianFourierProjection(
            time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        timestep_input_dim = time_embed_dim
    elif time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or base_channels * 4

        # The projection width depends on the model flavour and, for nested models, on the conditioning scale.
        if self.model_type == "unet":
            proj_channels = base_channels
        elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 256:
            proj_channels = base_channels * 4
        elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 1024:
            proj_channels = base_channels * 4 * 2
        else:
            # Previously this case fell through and left `self.time_proj` undefined.
            raise ValueError(
                f"No positional time projection is defined for `model_type`={self.model_type!r} with "
                f"`micro_conditioning_scale`={self.config.micro_conditioning_scale!r}."
            )
        self.time_proj = Timesteps(proj_channels, flip_sin_to_cos, freq_shift)
        timestep_input_dim = base_channels
    else:
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )

    return time_embed_dim, timestep_input_dim
|
|
|
def _set_encoder_hid_proj(
    self,
    encoder_hid_dim_type: Optional[str],
    cross_attention_dim: Union[int, Tuple[int]],
    encoder_hid_dim: Optional[int],
):
    """
    Create `self.encoder_hid_proj`, the optional projection applied to the encoder hidden states.

    When only `encoder_hid_dim` is given, the type defaults to `"text_proj"` and that choice is written back to
    the config. When neither is given, `self.encoder_hid_proj` is set to `None`.

    Raises:
        ValueError: If `encoder_hid_dim_type` is set without `encoder_hid_dim`, or is not one of `'text_proj'`,
            `'text_image_proj'`, `'image_proj'`.
    """
    if encoder_hid_dim_type is None:
        if encoder_hid_dim is None:
            # Nothing requested: no projection layer.
            self.encoder_hid_proj = None
            return
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
    elif encoder_hid_dim is None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        projection = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # `cross_attention_dim` is used for both the image embedding size and the output size.
        projection = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        projection = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    else:
        raise ValueError(
            f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
        )

    self.encoder_hid_proj = projection
|
|
|
def _set_class_embedding(
    self,
    class_embed_type: Optional[str],
    act_fn: str,
    num_class_embeds: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
    timestep_input_dim: int,
):
    """
    Create `self.class_embedding` according to `class_embed_type`.

    `None` with `num_class_embeds` set yields an `nn.Embedding`; `"timestep"`, `"identity"`, `"projection"` and
    `"simple_projection"` select the corresponding layer; anything else leaves `self.class_embedding` as `None`.

    Raises:
        ValueError: If `"projection"` or `"simple_projection"` is requested without
            `projection_class_embeddings_input_dim`.
    """
    embedding = None

    if class_embed_type is None:
        if num_class_embeds is not None:
            embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # Uses the same layer class as the "timestep" type, but fed with an input of the given
        # projection dimension.
        embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)

    self.class_embedding = embedding
|
|
|
def _set_add_embedding( |
|
self, |
|
addition_embed_type: str, |
|
addition_embed_type_num_heads: int, |
|
addition_time_embed_dim: Optional[int], |
|
flip_sin_to_cos: bool, |
|
freq_shift: float, |
|
cross_attention_dim: Optional[int], |
|
encoder_hid_dim: Optional[int], |
|
projection_class_embeddings_input_dim: Optional[int], |
|
time_embed_dim: int, |
|
): |
|
if addition_embed_type == "text": |
|
if encoder_hid_dim is not None: |
|
text_time_embedding_from_dim = encoder_hid_dim |
|
else: |
|
text_time_embedding_from_dim = cross_attention_dim |
|
|
|
self.add_embedding = TextTimeEmbedding( |
|
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads |
|
) |
|
elif addition_embed_type == "matryoshka": |
|
self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( |
|
self.config.time_embedding_dim // 4 |
|
if self.config.time_embedding_dim is not None |
|
else addition_time_embed_dim, |
|
cross_attention_dim, |
|
time_embed_dim, |
|
self.model_type, |
|
) |
|
elif addition_embed_type == "text_image": |
|
|
|
|
|
|
|
self.add_embedding = TextImageTimeEmbedding( |
|
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim |
|
) |
|
elif addition_embed_type == "text_time": |
|
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) |
|
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) |
|
elif addition_embed_type == "image": |
|
|
|
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) |
|
elif addition_embed_type == "image_hint": |
|
|
|
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) |
|
elif addition_embed_type is not None: |
|
raise ValueError( |
|
f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." |
|
) |
|
|
|
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): |
|
if attention_type in ["gated", "gated-text-image"]: |
|
positive_len = 768 |
|
if isinstance(cross_attention_dim, int): |
|
positive_len = cross_attention_dim |
|
elif isinstance(cross_attention_dim, (list, tuple)): |
|
positive_len = cross_attention_dim[0] |
|
|
|
feature_type = "text-only" if attention_type == "gated" else "text-image" |
|
self.position_net = GLIGENTextBoundingboxProjection( |
|
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type |
|
) |
|
|
|
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processors of every sub-module that exposes `get_processor`.

    Returns:
        `dict`: Maps `"<dotted module path>.processor"` to the processor returned by that module's
        `get_processor()`. Entries appear in depth-first, parent-before-children order.
    """
    collected = {}

    # Explicit stack instead of recursion. Children are pushed in reverse so that they are popped (and therefore
    # visited) in registration order.
    pending = list(reversed(list(self.named_children())))
    while pending:
        path, module = pending.pop()

        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor()

        for child_name, child in reversed(list(module.named_children())):
            pending.append((f"{path}.{child_name}", child))

    return collected
|
|
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, each key must be the `"<dotted module path>.processor"` name of the layer it
            applies to (the same keys `attn_processors` returns). The dict is only read; it is not modified.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention processors.
        KeyError: If a dict is passed that lacks the key for one of the layers.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up instead of popping it so the caller's mapping stays intact and can be
                # passed in again later.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
|
|
|
def set_default_attn_processor(self):
    """
    Replace every attention processor with a default implementation.

    `AttnAddedKVProcessor` is chosen when all current processors belong to `ADDED_KV_ATTENTION_PROCESSORS`,
    `AttnProcessor` when all belong to `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: If the current processors fit neither group.
    """

    def every_processor_in(allowed) -> bool:
        return all(proc.__class__ in allowed for proc in self.attn_processors.values())

    if every_processor_in(ADDED_KV_ATTENTION_PROCESSORS):
        default = AttnAddedKVProcessor()
    elif every_processor_in(CROSS_ATTENTION_PROCESSORS):
        default = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default)
|
|
|
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    Every sub-module that exposes `set_attention_slice` receives one slice size; computing attention in slices
    trades a little speed for lower memory use.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            `"auto"` gives each layer half of its `sliceable_head_dim`; `"max"` gives every layer a slice size
            of 1; a single number is applied to all layers; a list supplies one value per sliceable layer, in
            traversal order.

    Raises:
        ValueError: If a list of the wrong length is given, or a slice size exceeds its layer's
            `sliceable_head_dim`.
    """

    def iter_sliceable(module: torch.nn.Module):
        # Depth-first walk, parent before children.
        if hasattr(module, "set_attention_slice"):
            yield module
        for child in module.children():
            yield from iter_sliceable(child)

    def sliceable_modules():
        for top_level in self.children():
            yield from iter_sliceable(top_level)

    sliceable_head_dims = [module.sliceable_head_dim for module in sliceable_modules()]
    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        slice_size = num_sliceable_layers * [1]

    # A single value (anything that is not a list) is broadcast to every sliceable layer.
    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Second walk in the same order, handing the i-th size to the i-th sliceable module.
    for module, size in zip(sliceable_modules(), slice_size):
        module.set_attention_slice(size)
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
if hasattr(module, "gradient_checkpointing"): |
|
module.gradient_checkpointing = value |
|
|
|
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

    The four factors are stored as attributes `s1`, `s2`, `b1` and `b2` on every block in `self.up_blocks`. The
    numeric suffix names the stage the factor applies to.

    Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
    are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

    Args:
        s1 (`float`):
            Stage 1 scaling factor that attenuates the contributions of the skip features, to mitigate the
            "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Stage 2 scaling factor that attenuates the contributions of the skip features, to mitigate the
            "oversmoothing effect" in the enhanced denoising process.
        b1 (`float`): Stage 1 scaling factor that amplifies the contributions of backbone features.
        b2 (`float`): Stage 2 scaling factor that amplifies the contributions of backbone features.
    """
    factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
    for upsample_block in self.up_blocks:
        for attr_name, factor in factors.items():
            setattr(upsample_block, attr_name, factor)
|
|
|
def disable_freeu(self):
    """Disables the FreeU mechanism by resetting any `s1`/`s2`/`b1`/`b2` attribute on the up blocks to `None`."""
    for upsample_block in self.up_blocks:
        for key in ("s1", "s2", "b1", "b2"):
            # Only touch attributes that already exist; missing ones are not created.
            if hasattr(upsample_block, key):
                setattr(upsample_block, key, None)
|
|
|
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The processors in use before fusing are kept on `self.original_attn_processors`.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Raises:
        ValueError: If any current attention processor's class name contains `"Added"`.
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Snapshot the processors before they are replaced below.
    self.original_attn_processors = self.attn_processors

    attention_modules = (module for module in self.modules() if isinstance(module, Attention))
    for attention_module in attention_modules:
        attention_module.fuse_projections(fuse=True)

    self.set_attn_processor(FusedAttnProcessor2_0())
|
|
|
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors saved on `self.original_attn_processors`. Does nothing when that attribute is
    missing or `None`.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # Tolerate models on which fusion was never attempted (attribute not set).
    saved = getattr(self, "original_attn_processors", None)
    if saved is None:
        return

    # Hand over a shallow copy: a setter that consumes the mapping it receives would otherwise leave the saved
    # dict empty and make a later call fail.
    self.set_attn_processor(dict(saved) if isinstance(saved, dict) else saved)
|
|
|
def get_time_embed(
    self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
    """
    Project `timestep` with `self.time_proj` and return the result in `sample`'s dtype.

    Args:
        sample: Only its device, dtype and first dimension are used.
        timestep: A Python number or a tensor. Numbers and 0-d tensors are turned into a one-element tensor on
            `sample.device`; the result is then expanded to `sample.shape[0]` entries.

    Returns:
        The output of `self.time_proj`, cast to `sample.dtype`.
    """
    if torch.is_tensor(timestep):
        timesteps = timestep
        if len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
    else:
        # On "mps" devices 32-bit types are used instead of the 64-bit ones.
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if on_mps else torch.float64
        else:
            dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timestep], dtype=dtype, device=sample.device)

    timesteps = timesteps.expand(sample.shape[0])

    # The projection output is cast explicitly so that it matches `sample`'s dtype.
    return self.time_proj(timesteps).to(dtype=sample.dtype)
|
|
|
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """
    Compute the class embedding, or return `None` when the model has no `class_embedding` layer.

    Args:
        sample: Only its dtype is used, as the target dtype of the result.
        class_labels: Labels fed to `self.class_embedding`. When `config.class_embed_type` is `"timestep"` they
            are first passed through `self.time_proj`.

    Raises:
        ValueError: If a class embedding layer exists but `class_labels` is `None`.
    """
    if self.class_embedding is None:
        return None

    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        # The projection output is cast explicitly to `sample`'s dtype before it reaches the embedding layer.
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

    return self.class_embedding(class_labels).to(dtype=sample.dtype)
|
|
|
def get_aug_embed(
    self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:
    """
    Compute the additional embedding selected by `config.addition_embed_type`.

    Args:
        emb: Current time embedding. Passed to the `"matryoshka"` embedding, and its dtype is the target dtype
            of the `"text_time"` input.
        encoder_hidden_states: Text conditioning. Used by `"text"` and `"matryoshka"`, and as the fallback text
            input of `"text_image"`.
        added_cond_kwargs: Extra conditioning inputs. Required keys are `image_embeds` for `"text_image"` and
            `"image"`, `text_embeds` and `time_ids` for `"text_time"`, `image_embeds` and `hint` for
            `"image_hint"`.

    Returns:
        Whatever `self.add_embedding` returns for the configured type, or `None` when the type matches none of
        the handled values.

    Raises:
        ValueError: If a key required by the configured type is missing from `added_cond_kwargs`.
    """
    embed_type = self.config.addition_embed_type

    if embed_type == "text":
        return self.add_embedding(encoder_hidden_states)

    if embed_type == "matryoshka":
        return self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs)

    if embed_type == "text_image":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        # Explicit `text_embeds` win over the encoder hidden states.
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        image_embs = added_cond_kwargs.get("image_embeds")
        return self.add_embedding(text_embs, image_embs)

    if embed_type == "text_time":
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")

        # Project the flattened ids, then regroup them so there is one row per row of `text_embeds`.
        time_embeds = self.add_time_proj(time_ids.flatten()).reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
        return self.add_embedding(add_embeds)

    if embed_type == "image":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"))

    if embed_type == "image_hint":
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"), added_cond_kwargs.get("hint"))

    return None
|
|
|
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
    """
    Apply `self.encoder_hid_proj` according to `config.encoder_hid_dim_type`.

    Args:
        encoder_hidden_states: Text conditioning to project.
        added_cond_kwargs: Must contain `image_embeds` for the `"text_image_proj"`, `"image_proj"` and
            `"ip_image_proj"` types.

    Returns:
        The projected hidden states. For `"ip_image_proj"` this is a `(text_states, projected_image_embeds)`
        tuple. The input is returned unchanged when there is no projection layer or the type is not handled.

    Raises:
        ValueError: If the configured type needs `image_embeds` and the key is missing.
    """
    projection = self.encoder_hid_proj
    if projection is None:
        return encoder_hidden_states

    proj_type = self.config.encoder_hid_dim_type

    if proj_type == "text_proj":
        return projection(encoder_hidden_states)

    if proj_type == "text_image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return projection(encoder_hidden_states, added_cond_kwargs.get("image_embeds"))

    if proj_type == "image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        return projection(added_cond_kwargs.get("image_embeds"))

    if proj_type == "ip_image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        # An optional separate text projection runs before the image embeddings are projected.
        if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
            encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)

        projected_images = projection(added_cond_kwargs.get("image_embeds"))
        return (encoder_hidden_states, projected_images)

    return encoder_hidden_states
|
|
|
@property
def model_type(self) -> str:
    """Architecture identifier string; this model always reports `"unet"`."""
    return "unet"
|
|
|
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    cond_emb: Optional[torch.Tensor] = None,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    from_nested: bool = False,
) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]:
    r"""
    Run the UNet on `sample`: embed the timestep and conditioning, then apply the down blocks, the optional mid
    block, the up blocks and the output head.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`. When
            `config.nesting` is enabled the argument is instead unpacked as a `(sample, sample_feat)` pair and
            `sample_feat` is added to the output of `conv_in`.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        cond_emb (`torch.Tensor`, *optional*, defaults to `None`):
            Conditioning embedding that is added to the time embedding together with the augmentation embedding.
            The passed value is only used when `from_nested` is `True`; otherwise it is replaced by the third
            value returned by `get_aug_embed`.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings
            (or concatenated when `config.class_embeddings_concat` is set).
        timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep. Passed as the second argument to `self.time_embedding`.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            The `"scale"` entry is removed from a copy and used as the LoRA scale; a `"gligen"` entry is replaced
            by the output of `self.position_net`.
        added_cond_kwargs: (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
            are passed along to the UNet blocks. The keys `masked_cross_attention`, `micro_conditioning_scale`,
            `from_nested` and `conditioning_mask` are written into it by this method.
        down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
            A tuple of tensors that if specified are added to the residuals of down unet blocks.
        mid_block_additional_residual: (`torch.Tensor`, *optional*):
            A tensor that if specified is added to the residual of the middle unet block.
        down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. It is handed to `get_aug_embed` as
            `conditioning_mask`; the mask returned from there is converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`MatryoshkaUNet2DConditionOutput`] instead of a plain
            tuple.
        from_nested (`bool`, *optional*, defaults to `False`):
            When `True`, `process_encoder_hidden_states` is skipped and the `cond_emb` argument is used as given.
            NOTE(review): presumably set when this model runs inside a nested UNet; confirm against the caller.

    Returns:
        [`MatryoshkaUNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, a [`MatryoshkaUNet2DConditionOutput`] is returned (additionally carrying
            `sample_inner`, the up-block output before the output head, when `config.nesting` is enabled),
            otherwise a `tuple` is returned whose only element is the sample tensor.
    """
    # Spatial sizes are checked below against 2 ** (number of upsamplers).
    default_overall_up_factor = 2**self.num_upsamplers

    # Set to True below when a spatial dimension of `sample` is not a multiple of `default_overall_up_factor`;
    # the up blocks then receive an explicit `upsample_size`.
    forward_upsample_size = False
    upsample_size = None

    if self.config.nesting:
        sample, sample_feat = sample
    if isinstance(sample, list) and len(sample) == 1:
        sample = sample[0]

    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            forward_upsample_size = True
            break

    # Convert the keep(1)/discard(0) mask into an additive bias: 0.0 for "keep", -10000.0 for "discard".
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)

    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    # NOTE(review): when the caller passes a non-empty dict, the assignments below write into that same dict.
    added_cond_kwargs = added_cond_kwargs or {}
    added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
    added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
    added_cond_kwargs["from_nested"] = from_nested
    added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

    # `get_aug_embed` is unpacked into three values here. In the nested case the third one is discarded and the
    # `cond_emb` argument is kept.
    if not from_nested:
        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
    else:
        aug_emb, encoder_attention_mask, _ = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

    # Convert encoder_attention_mask to a bias the same way as attention_mask above.
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    if self.config.addition_embed_type == "image_hint":
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb + cond_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    # 2. pre-process
    sample = self.conv_in(sample)
    if self.config.nesting:
        sample = sample + sample_feat

    # 2.5 GLIGEN position net: replace the raw "gligen" arguments by the position net output (on a copy).
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

    # 3. down
    # The LoRA scale is taken out of a copy of `cross_attention_kwargs`, so the caller's dict is not changed.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Undone by `unscale_lora_layers` near the end of this method.
        scale_lora_layers(self, lora_scale)

    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = down_intrablock_additional_residuals is not None

    # Deprecated calling convention: only `down_block_additional_residuals` given (no mid-block residual and no
    # intrablock residuals). The values are then treated as intrablock (adapter) residuals.
    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
        # The continuation lines of the message start at column 0 so that no indentation ends up in the string.
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True

    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # Cross-attention blocks receive the adapter residual as a keyword argument.
            # NOTE(review): `.pop(0)` needs a list (not a tuple, despite the annotation) and consumes the
            # caller's sequence.
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            # Blocks without cross attention get the adapter residual added to their output instead.
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        # Add the ControlNet residuals element-wise to the collected skip connections.
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = self.mid_block(sample, emb)

        # A leftover adapter residual is added here only if its shape matches the mid-block output.
        if (
            is_adapter
            and len(down_intrablock_additional_residuals) > 0
            and sample.shape == down_intrablock_additional_residuals[0].shape
        ):
            sample += down_intrablock_additional_residuals.pop(0)

    if is_controlnet:
        sample = sample + mid_block_additional_residual

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Take one skip connection per resnet of this block from the end of the collected tuple.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # If this is not the final block and the upsample size must be forwarded, take it from the spatial
        # shape of the next remaining skip connection.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )

    # Keep the up-block output from before the output head; it is returned as `sample_inner` when nesting.
    sample_inner = sample

    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample_inner)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if USE_PEFT_BACKEND:
        # Remove `lora_scale` from each PEFT layer again.
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (sample,)

    if self.config.nesting:
        return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner)

    return MatryoshkaUNet2DConditionOutput(sample=sample)
|
|
|
|
|
class NestedUNet2DConditionOutput(BaseOutput):
    """
    Output type for the [`NestedUNet2DConditionModel`] model.

    Declares two fields, `sample` and `sample_inner`, both defaulting to `None`.
    """

    # NOTE(review): no `@dataclass` decorator is visible on this class; confirm whether one is intended.

    # Annotated as a list. Presumably one output per nesting level; TODO confirm against the producer.
    sample: list = None
    # Presumably the inner feature map that precedes the output head; TODO confirm against the producer.
    sample_inner: torch.Tensor = None
|
|
|
|
|
class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): |
|
""" |
|
Nested UNet model with condition for image denoising. |
|
""" |
|
|
|
@register_to_config
def __init__(
    self,
    in_channels=3,
    out_channels=3,
    block_out_channels=(64, 128, 256),
    cross_attention_dim=2048,
    resnet_time_scale_shift="scale_shift",
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
    mid_block_type=None,
    nesting=False,
    flip_sin_to_cos=False,
    transformer_layers_per_block=[0, 0, 0],
    layers_per_block=[2, 2, 1],
    masked_cross_attention=True,
    micro_conditioning_scale=256,
    addition_embed_type="matryoshka",
    skip_normalization=True,
    time_embedding_dim=1024,
    skip_inner_unet_input=False,
    temporal_mode=False,
    temporal_spatial_ds=False,
    initialize_inner_with_pretrained=None,
    use_attention_ffn=False,
    act_fn="silu",
    addition_embed_type_num_heads=64,
    addition_time_embed_dim=None,
    attention_head_dim=8,
    attention_pre_only=False,
    attention_type="default",
    center_input_sample=False,
    class_embed_type=None,
    class_embeddings_concat=False,
    conv_in_kernel=3,
    conv_out_kernel=3,
    cross_attention_norm=None,
    downsample_padding=1,
    dropout=0.0,
    dual_cross_attention=False,
    encoder_hid_dim=None,
    encoder_hid_dim_type=None,
    freq_shift=0,
    mid_block_only_cross_attention=None,
    mid_block_scale_factor=1,
    norm_eps=1e-05,
    norm_num_groups=32,
    norm_type="layer_norm",
    num_attention_heads=None,
    num_class_embeds=None,
    only_cross_attention=False,
    projection_class_embeddings_input_dim=None,
    resnet_out_scale_factor=1.0,
    resnet_skip_time_act=False,
    reverse_transformer_layers_per_block=None,
    sample_size=None,
    skip_cond_emb=False,
    time_cond_proj_dim=None,
    time_embedding_act_fn=None,
    time_embedding_type="positional",
    timestep_post_act=None,
    upcast_attention=False,
    use_linear_projection=False,
    is_temporal=None,
    inner_config={},
):
    """
    Build the outer UNet, the inner UNet described by `inner_config`, and the adapters between them.

    Only a subset of the arguments is forwarded to the parent constructor (see the `super().__init__` call);
    the rest are not otherwise used in this method's body. After construction the instance carries
    `inner_unet`, `in_adapter` (`None` when `skip_inner_unet_input` is set), `out_adapter`, `is_temporal` and
    `nest_ratio`.

    NOTE(review): the list/dict defaults in the signature are shared between calls; this method's own body
    does not modify them.
    """
    super().__init__(
        in_channels=in_channels,
        out_channels=out_channels,
        block_out_channels=block_out_channels,
        cross_attention_dim=cross_attention_dim,
        resnet_time_scale_shift=resnet_time_scale_shift,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        mid_block_type=mid_block_type,
        nesting=nesting,
        flip_sin_to_cos=flip_sin_to_cos,
        transformer_layers_per_block=transformer_layers_per_block,
        layers_per_block=layers_per_block,
        masked_cross_attention=masked_cross_attention,
        micro_conditioning_scale=micro_conditioning_scale,
        addition_embed_type=addition_embed_type,
        time_embedding_dim=time_embedding_dim,
        temporal_mode=temporal_mode,
        temporal_spatial_ds=temporal_spatial_ds,
        use_attention_ffn=use_attention_ffn,
        sample_size=sample_size,
    )

    inner_cfg = self.config.inner_config

    # An inner config that itself contains an "inner_config" entry describes another nested level.
    inner_cls = NestedUNet2DConditionModel if "inner_config" in inner_cfg else MatryoshkaUNet2DConditionModel
    self.inner_unet = inner_cls(**inner_cfg)

    # 3x3 convolutions mapping between the outer model's last block width and the inner model's first.
    outer_channels = self.config.block_out_channels[-1]
    inner_channels = inner_cfg["block_out_channels"][0]
    if self.config.skip_inner_unet_input:
        self.in_adapter = None
    else:
        self.in_adapter = nn.Conv2d(outer_channels, inner_channels, kernel_size=3, padding=1)
    self.out_adapter = nn.Conv2d(inner_channels, outer_channels, kernel_size=3, padding=1)

    # One flag per nesting level, outermost first.
    temporal_flags = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)]
    if hasattr(self.inner_unet, "is_temporal"):
        temporal_flags = temporal_flags + self.inner_unet.is_temporal
    self.is_temporal = temporal_flags

    level_ratio = int(2 ** (len(self.config.block_out_channels) - 1))
    if self.is_temporal[0]:
        level_ratio = int(np.sqrt(level_ratio))

    inner_is_nested = self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet"
    if inner_is_nested:
        # Prepend this level's ratio multiplied by the inner model's outermost ratio.
        self.nest_ratio = [level_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio
    else:
        self.nest_ratio = [level_ratio]
|
|
|
|
|
|
|
@property
def model_type(self):
    """Return the architecture tag of this model, the literal string ``"nested_unet"``.

    `forward` compares `inner_unet.model_type` against this tag and against `"unet"` to decide
    how conditioning embeddings are computed.
    """
    architecture_tag = "nested_unet"
    return architecture_tag
|
|
|
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    cond_emb: Optional[torch.Tensor] = None,
    from_nested: bool = False,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]:
    r"""
    The [`NestedUNet2DConditionModel`] forward method.

    This level runs its own down blocks on the first input, hands the remaining inputs (plus an adapted feature
    map) to `self.inner_unet` in place of a mid block, adds the inner result back, and then runs its own up
    blocks.

    Args:
        sample:
            The noisy inputs. The body treats this as a sequence of tensors: `sample[0]` is processed by this
            level and `sample[1:]` is forwarded to `self.inner_unet`. When `self.config.nesting` is set, the
            argument is first unpacked as a `(samples, feature_map)` pair and `feature_map` is added to the
            output of `conv_in`. NOTE(review): the `torch.Tensor` annotation does not match this usage.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        cond_emb (`torch.Tensor`, *optional*, defaults to `None`):
            Conditioning embedding that is summed into the time embedding together with `aug_emb`. It is
            recomputed from the inner UNet (and the passed value ignored) unless this model is itself nested
            around a plain UNet.
        from_nested (`bool`, *optional*, defaults to `False`):
            Accepted for interface compatibility; this method passes `from_nested=True` to the inner UNet but
            does not read the value itself.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
            through the `self.time_embedding` layer to obtain the timestep embeddings.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            A `"scale"` entry, if present, is removed from a copy and used as the LoRA scale.
        added_cond_kwargs: (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
            are passed along to the UNet blocks. The keys `masked_cross_attention`, `micro_conditioning_scale`
            and `conditioning_mask` are written by this method.
        down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
            A tuple of tensors that if specified are added to the residuals of down unet blocks.
        mid_block_additional_residual: (`torch.Tensor`, *optional*):
            A tensor that if specified is added to the residual of the middle unet block.
        down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain
            tuple. Ignored when `self.config.nesting` is set: the output object is always returned then.

    Returns:
        [`~NestedUNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the list of predictions (this level's
            output first, followed by the inner UNet's outputs).
    """
    # If a spatial dimension of the input is not a multiple of this factor, explicit target sizes are
    # forwarded to the up blocks (see `forward_upsample_size` / `upsample_size` below).
    default_overall_up_factor = 2**self.num_upsamplers

    forward_upsample_size = False
    upsample_size = None

    # When this model is itself an inner model, the caller passes `(samples, feature_map)`.
    if self.config.nesting:
        sample, sample_feat = sample
    if isinstance(sample, list) and len(sample) == 1:
        sample = sample[0]

    # Batch sizes of the first two inputs: `bh` for the input handled here, `bl` for the next one.
    # NOTE(review): this indexing requires `sample` to still hold at least two tensors at this point.
    bsz = [x.size(0) for x in sample]
    bh, bl = bsz[0], bsz[1]
    # Keep the first input for this level; the rest is handed to the inner UNet.
    x_t_low, sample = sample[1:], sample[0]

    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Not evenly divisible: the up blocks must be told the exact size to produce.
            forward_upsample_size = True
            break

    # Convert a keep(1)/discard(0) mask into an additive bias (0 / -10000) with a singleton axis at dim 1.
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # Optionally rescale the input as 2 * x - 1.
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # Time embedding, optionally combined with a class embedding.
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)

    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    # Conditioning embeddings. Which model computes them depends on what the inner model is.
    # NOTE(review): a non-empty `added_cond_kwargs` passed by the caller is mutated in place below.
    # NOTE(review): if `model_type` is neither "unet" nor "nested_unet", `aug_emb` and `cond_mask` stay unbound.
    if self.inner_unet.model_type == "unet":
        added_cond_kwargs = added_cond_kwargs or {}
        added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
        added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
        added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

        if not self.config.nesting:
            # Outermost level: the innermost UNet processes the text states and provides mask + `cond_emb`.
            encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )

            aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )
            # Switch to this level's own setting before computing this level's `aug_emb`.
            added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
            aug_emb, __, _ = self.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )
        else:
            # Nested level: `encoder_hidden_states` and `cond_emb` are used as received from the caller.
            aug_emb, cond_mask, _ = self.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )

    elif self.inner_unet.model_type == "nested_unet":
        # The inner model is itself nested: conditioning comes from the model two levels down.
        added_cond_kwargs = added_cond_kwargs or {}
        added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
        added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
        added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

        encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        aug_emb, __, _ = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

    # Same bias conversion as for `attention_mask`. The blocks below receive `cond_mask`, not this value.
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    if self.config.addition_embed_type == "image_hint":
        # In this mode `aug_emb` is a pair; the hint is concatenated to the input along dim 1.
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)

    # NOTE(review): `cond_emb` must not be None whenever `aug_emb` is set, otherwise this addition fails.
    emb = emb + aug_emb + cond_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    # Input layer: optionally divide each sample by its std over dims (1, 2, 3), then project.
    if not self.config.skip_normalization:
        sample = sample / sample.std((1, 2, 3), keepdims=True)
    if isinstance(sample, list) and len(sample) == 1:
        sample = sample[0]
    sample = self.conv_in(sample)
    if self.config.nesting:
        # Add the feature map handed down by the enclosing level.
        sample = sample + sample_feat

    # Pop `scale` from a copy so the caller's dict is untouched and the blocks never see the key.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Applied for the duration of this call; reverted by `unscale_lora_layers` before returning.
        scale_lora_layers(self, lora_scale)

    is_adapter = down_intrablock_additional_residuals is not None

    # Backward compatibility: residuals passed through the old argument without a mid-block residual
    # are treated as intra-block (adapter) residuals, with a deprecation warning.
    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True

    # Down path. Every block's residuals are collected for the skip connections of the up path.
    # NOTE(review): `.pop(0)` below needs a mutable list although the annotation says tuple.
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            # Conditioning is cut to the first `bh` rows so it matches this level's batch.
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb[:bh],
                encoder_hidden_states=encoder_hidden_states[:bh],
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                **additional_residuals,
            )
        else:
            # NOTE(review): `emb` is not cut to `bh` here, unlike the cross-attention branch — confirm intended.
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    # Middle: the inner UNet takes the place of a mid block.
    # NOTE(review): when `in_adapter` is None and `bh < bl`, `x_inner.new_zeros` is called on None.
    x_inner = self.in_adapter(sample) if self.in_adapter is not None else None
    # Zero-pad along the batch axis so the feature map matches the inner model's larger batch.
    x_inner = (
        torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner
    )
    inner_unet_output = self.inner_unet(
        (x_t_low, x_inner),
        timestep,
        cond_emb=cond_emb,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=cond_mask,
        from_nested=True,
    )
    x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
    x_inner = self.out_adapter(x_inner)
    # Residual connection around the inner UNet; drop the padded rows again if any were added.
    sample = sample + x_inner[:bh] if bh < bl else sample + x_inner

    # Up path: each block consumes as many stored residuals as it has resnets, newest first.
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # Target size is the spatial shape of the next residual to be consumed.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb[:bh],
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[:bh],
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )

    # Output head. `sample` itself is left untouched so it can be returned as `sample_inner`.
    # NOTE(review): `sample_out` is only bound inside this `if`; a falsy `conv_norm_out` would make the
    # `conv_out` call fail — confirm the intended nesting against the unmangled file.
    if self.conv_norm_out:
        sample_out = self.conv_norm_out(sample)
        sample_out = self.conv_act(sample_out)
    sample_out = self.conv_out(sample_out)

    if USE_PEFT_BACKEND:
        # Revert the scaling applied at the start of this call.
        unscale_lora_layers(self, lora_scale)

    # Collect predictions: this level's output first, then whatever the inner UNet produced.
    if isinstance(x_low, list):
        out = [sample_out] + x_low
    else:
        out = [sample_out, x_low]
    if self.config.nesting:
        # An enclosing level also needs the pre-head feature map, so `return_dict` is not consulted here.
        return NestedUNet2DConditionOutput(sample=out, sample_inner=sample)
    if not return_dict:
        return (out,)
    else:
        return NestedUNet2DConditionOutput(sample=out)
|
|
|
|
|
@dataclass
class MatryoshkaPipelineOutput(BaseOutput):
    """
    Output class for Matryoshka pipelines.

    Args:
        images (`List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`, `np.ndarray` or `List[np.ndarray]`)
            The denoised images produced by the pipeline, either as PIL images or as numpy arrays of shape
            `(batch_size, height, width, num_channels)`. The nested/list variants are allowed by the annotation;
            presumably they hold one entry per generated resolution — TODO confirm against the pipeline's
            `__call__`.
    """

    # Generated images; see the class docstring for the accepted container types.
    images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]]
|
|
|
|
|
class MatryoshkaPipeline( |
|
DiffusionPipeline, |
|
StableDiffusionMixin, |
|
TextualInversionLoaderMixin, |
|
StableDiffusionLoraLoaderMixin, |
|
IPAdapterMixin, |
|
FromSingleFileMixin, |
|
): |
|
r""" |
|
Pipeline for text-to-image generation using Matryoshka Diffusion Models. |
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods |
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.). |
|
|
|
The pipeline also inherits the following loading methods: |
|
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings |
|
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights |
|
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights |
|
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files |
|
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters |
|
|
|
Args: |
|
text_encoder ([`~transformers.T5EncoderModel`]): |
|
Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)). |
|
tokenizer ([`~transformers.T5Tokenizer`]): |
|
A `T5Tokenizer` to tokenize text. |
|
unet ([`MatryoshkaUNet2DConditionModel`]): |
|
A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. |
|
scheduler ([`SchedulerMixin`]): |
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
|
[`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md. |
|
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` that converts non-tensor IP-Adapter images into pixel values for the `image_encoder`.
|
""" |
|
|
|
model_cpu_offload_seq = "text_encoder->image_encoder->unet" |
|
_optional_components = ["unet", "feature_extractor", "image_encoder"] |
|
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] |
|
|
|
def __init__(
    self,
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    scheduler: MatryoshkaDDIMScheduler,
    unet: MatryoshkaUNet2DConditionModel = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
    trust_remote_code: bool = False,
    nesting_level: int = 0,
):
    """Assemble the pipeline and load the UNet that matches `nesting_level`.

    Args:
        text_encoder: Text encoder registered as a pipeline module.
        tokenizer: Tokenizer registered as a pipeline module.
        scheduler: Scheduler; its config is patched if `steps_offset != 1`, and its `scales` /
            `schedule_shifted_power` attributes are set when the loaded UNet exposes `nest_ratio`.
        unet: Accepted for interface compatibility. The value is always replaced by weights loaded from
            `tolgacangoz/matryoshka-diffusion-models` for the requested nesting level.
        feature_extractor: Optional image processor registered as a pipeline module.
        image_encoder: Optional image encoder registered as a pipeline module.
        trust_remote_code: Accepted but not read by this constructor.
        nesting_level: Which checkpoint to load; `0`, `1` or `2`.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2.
    """
    super().__init__()

    # Choose model class and checkpoint subfolder first, then load through a single call.
    if nesting_level == 0:
        unet_cls, unet_subfolder = MatryoshkaUNet2DConditionModel, "unet/nesting_level_0"
    elif nesting_level == 1:
        unet_cls, unet_subfolder = NestedUNet2DConditionModel, "unet/nesting_level_1"
    elif nesting_level == 2:
        unet_cls, unet_subfolder = NestedUNet2DConditionModel, "unet/nesting_level_2"
    else:
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
    unet = unet_cls.from_pretrained("tolgacangoz/matryoshka-diffusion-models", subfolder=unet_subfolder)

    # Legacy scheduler configs: warn and force `steps_offset` to 1.
    has_steps_offset = hasattr(scheduler.config, "steps_offset")
    if has_steps_offset and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        patched_scheduler_config = dict(scheduler.config)
        patched_scheduler_config.update(steps_offset=1)
        scheduler._internal_dict = FrozenDict(patched_scheduler_config)

    # Legacy UNet configs: an old diffusers version combined with a small `sample_size` gets bumped to 64.
    unet_predates_0_9_0 = False
    if hasattr(unet.config, "_diffusers_version"):
        base_version = version.parse(unet.config._diffusers_version).base_version
        unet_predates_0_9_0 = version.parse(base_version) < version.parse("0.9.0.dev0")
    sample_size_below_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if unet_predates_0_9_0 and sample_size_below_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        patched_unet_config = dict(unet.config)
        patched_unet_config.update(sample_size=64)
        unet._internal_dict = FrozenDict(patched_unet_config)

    # Nested UNets expose their scale ratios; the scheduler gets them with a trailing 1 appended.
    if hasattr(unet, "nest_ratio"):
        scheduler.scales = unet.nest_ratio + [1]
        if nesting_level == 2:
            scheduler.schedule_shifted_power = 2.0

    self.register_modules(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self.register_to_config(nesting_level=nesting_level)
    self.image_processor = VaeImageProcessor(do_resize=False)
|
|
|
def change_nesting_level(self, nesting_level: int):
    """Swap the pipeline's UNet for the checkpoint of another nesting level.

    The new UNet is loaded from `tolgacangoz/matryoshka-diffusion-models`, moved to `self.device`, and the
    scheduler's `scales` / `schedule_shifted_power` attributes are updated for levels 1 and 2. For level 0
    the scheduler's `scales` is reset to `None` if the current UNet has a `nest_ratio`.

    Args:
        nesting_level: Target level; `0`, `1` or `2`.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2.
    """
    if nesting_level not in (0, 1, 2):
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

    repo_id = "tolgacangoz/matryoshka-diffusion-models"

    if nesting_level == 0:
        # Leaving a nested model: clear the scales before the plain UNet is loaded.
        if hasattr(self.unet, "nest_ratio"):
            self.scheduler.scales = None
        plain_unet = MatryoshkaUNet2DConditionModel.from_pretrained(repo_id, subfolder="unet/nesting_level_0")
        self.unet = plain_unet.to(self.device)
        self.config.nesting_level = 0
    else:
        # Levels 1 and 2 differ only in the checkpoint folder and the scheduler's shifted power.
        level = 1 if nesting_level == 1 else 2
        nested_unet = NestedUNet2DConditionModel.from_pretrained(repo_id, subfolder=f"unet/nesting_level_{level}")
        self.unet = nested_unet.to(self.device)
        self.config.nesting_level = level
        self.scheduler.scales = self.unet.nest_ratio + [1]
        self.scheduler.schedule_shifted_power = float(level)

    # Collect garbage and empty the CUDA cache after the swap.
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt. Not read by this method.
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of final text-encoder layers to skip while computing the prompt embeddings. A value of 1 means
            that the output of the pre-final layer will be used. Only honoured when
            `do_classifier_free_guidance` is `False`.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask)`.
        Without classifier-free guidance the second and fourth entries are `None`. With guidance, the
        embeddings are rows `1` and `0` of the jointly encoded `[negative, positive]` batch.

    NOTE(review): the encoding steps below read `text_input_ids`, which is only created when `prompt_embeds`
    is `None`; calling this method with pre-computed `prompt_embeds` is therefore not supported as written.
    NOTE(review): the same holds for `negative_prompt_embeds` together with guidance (`uncond_input_ids`).
    NOTE(review): selecting rows `1` and `0` of the guidance batch is only meaningful for a single prompt.
    """
    # Set the LoRA scale so the text encoder's LoRA layers can pick it up.
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # Dynamically adjust the LoRA scale.
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Textual inversion: expand multi-vector tokens if the mixin is present.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        # No padding/truncation is requested here, so the ids keep the prompt's natural length.
        text_inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            prompt_attention_mask = text_inputs.attention_mask.to(device)
        else:
            prompt_attention_mask = None

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    # Tokenize the unconditional prompt for classifier-free guidance.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # Textual inversion: expand multi-vector tokens if the mixin is present.
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        uncond_input = self.tokenizer(
            uncond_tokens,
            return_tensors="pt",
        )
        uncond_input_ids = uncond_input.input_ids

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
        else:
            negative_prompt_attention_mask = None

    if not do_classifier_free_guidance:
        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
            )
            # The last output entry is the tuple of per-layer hidden states; pick the requested layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # Intermediate hidden states have not passed through the encoder's final layer norm, so apply it
            # here. CLIP-style encoders expose it under `text_model`; `T5EncoderModel`, the encoder type this
            # pipeline declares, has no `text_model` and exposes it under `encoder`.
            clip_style_stack = getattr(self.text_encoder, "text_model", None)
            if clip_style_stack is not None:
                final_layer_norm = clip_style_stack.final_layer_norm
            else:
                final_layer_norm = self.text_encoder.encoder.final_layer_norm
            prompt_embeds = final_layer_norm(prompt_embeds)
    else:
        # Right-pad the shorter of the two id tensors (and its mask) with zeros so both can be stacked.
        # NOTE(review): this path requires both attention masks to be tensors, i.e. the text encoder config
        # must set `use_attention_mask`.
        max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0]))
        if len(text_input_ids[0]) < max_len:
            text_input_ids = torch.cat(
                [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)],
                dim=1,
            )
            prompt_attention_mask = torch.cat(
                [
                    prompt_attention_mask,
                    torch.zeros(
                        batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device
                    ),
                ],
                dim=1,
            )
        elif len(uncond_input_ids[0]) < max_len:
            uncond_input_ids = torch.cat(
                [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)],
                dim=1,
            )
            negative_prompt_attention_mask = torch.cat(
                [
                    negative_prompt_attention_mask,
                    torch.zeros(
                        batch_size,
                        max_len - len(negative_prompt_attention_mask[0]),
                        dtype=torch.long,
                        device=device,
                    ),
                ],
                dim=1,
            )
        # Encode negative and positive prompts in one forward pass: negative rows first.
        cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
        cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        prompt_embeds = self.text_encoder(
            cfg_input_ids.to(device),
            attention_mask=cfg_attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Restore the original scale of the LoRA layers.
            unscale_lora_layers(self.text_encoder, lora_scale)

    if not do_classifier_free_guidance:
        return prompt_embeds, None, prompt_attention_mask, None
    return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
|
|
|
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image with `self.image_encoder` and build the matching unconditional embedding.

    Args:
        image: A tensor, or any input accepted by `self.feature_extractor` (converted to pixel values).
        device: Device the pixel values are moved to.
        num_images_per_prompt: Repeat count applied along dim 0 via `repeat_interleave`.
        output_hidden_states: If truthy, return the second-to-last hidden state instead of `image_embeds`.

    Returns:
        `(conditional, unconditional)`. In hidden-state mode the unconditional part is the encoding of an
        all-zero image; otherwise it is a zero tensor shaped like the conditional embedding.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        # Pooled-embedding mode: the unconditional counterpart is plain zeros.
        cond_embeds = self.image_encoder(pixels).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    # Hidden-state mode: run the real image and a zero image through the encoder, in that order.
    repeated_states = []
    for encoder_input in (pixels, torch.zeros_like(pixels)):
        penultimate = self.image_encoder(encoder_input, output_hidden_states=True).hidden_states[-2]
        repeated_states.append(penultimate.repeat_interleave(num_images_per_prompt, dim=0))
    cond_states, uncond_states = repeated_states
    return cond_states, uncond_states
|
|
|
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter image embeddings handed to the UNet.

    Either encodes `ip_adapter_image` (one image per IP-Adapter projection layer) or reuses the pre-computed
    `ip_adapter_image_embeds`. Each embedding is tiled `num_images_per_prompt` times along dim 0; under
    classifier-free guidance the tiled negative embedding is concatenated in front of the positive one.

    Args:
        ip_adapter_image: Image or list of images; only used when `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional list of pre-computed embeddings. Under guidance each entry is split
            in two halves along dim 0, negative first.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt: Tiling factor along dim 0.
        do_classifier_free_guidance: Whether negative embeddings are produced and prepended.

    Returns:
        List with one tensor per adapter.

    Raises:
        ValueError: If the number of images differs from the number of IP-Adapter projection layers.
    """
    positive_embeds = []
    negative_embeds = []  # Only filled when guidance is enabled.

    if ip_adapter_image_embeds is not None:
        for precomputed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, precomputed = precomputed.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(precomputed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        for adapter_image, projection in zip(ip_adapter_image, projection_layers):
            # Anything other than a plain `ImageProjection` consumes hidden states rather than pooled embeds.
            wants_hidden_states = not isinstance(projection, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(cond[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positive_embeds):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
|
|
|
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    The signature of `self.scheduler.step` is inspected; `eta` and `generator` are only included when the
    scheduler declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` if supported.
        eta: Value forwarded as `eta` if supported.

    Returns:
        Dict with zero, one or both of the keys `"eta"` and `"generator"` (in that order).
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
|
|
|
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Each check raises `ValueError` on the first violation, in this order: image size divisible by 8,
    `callback_steps` a positive int, callback tensor names known to the pipeline, exactly one of
    `prompt` / `prompt_embeds` given (and `prompt` a `str` or `list`), not both `negative_prompt` and
    `negative_prompt_embeds`, matching shapes of the two embedding tensors, not both IP-Adapter inputs,
    and `ip_adapter_image_embeds` a list whose first tensor is 3D or 4D.

    Returns:
        None. The method only raises.
    """
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
        if not is_positive_int:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        if any(k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    has_ip_embeds = ip_adapter_image_embeds is not None
    if ip_adapter_image is not None and has_ip_embeds:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if not has_ip_embeds:
        return
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
|
|
|
def prepare_latents(
    self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None
):
    """
    Create (or move to `device`) the starting latents and scale them by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Effective batch size of the latents to create.
        num_channels_latents: Number of channels of each latent tensor.
        height: Height of the highest-resolution latent (cast with `int`).
        width: Width of the highest-resolution latent (cast with `int`).
        dtype: Data type of freshly sampled latents.
        device: Device the latents are created on / moved to.
        generator: A `torch.Generator`, a list of generators (one per batch element) or `None`.
        scales: `None` for a single-resolution model, otherwise a sequence of scales. `scales[0]` is the
            reference scale; every following entry `s` yields an extra noise tensor whose spatial size is the
            reference size divided by `scales[0] // s`.
        latents: Optional pre-generated latents. A single tensor when `scales` is `None`, otherwise an iterable
            of tensors (one per scale). They are only moved to `device`, never re-sampled.

    Returns:
        A tensor when `scales` is `None`, otherwise a list of tensors (highest resolution first), each multiplied
        by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    height, width = int(height), int(width)
    shape = (batch_size, num_channels_latents, height, width)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if scales is not None:
            nested = [latents]
            for s in scales[1:]:
                ratio = scales[0] // s
                # Same spatial size a stride-`ratio` average pooling of the full-resolution latent would have.
                # Sampling through `randn_tensor` (instead of an in-place `normal_`) supports lists of
                # generators and generators living on a different device than the latents.
                low_res_shape = (batch_size, num_channels_latents, height // ratio, width // ratio)
                nested.append(randn_tensor(low_res_shape, generator=generator, device=device, dtype=dtype))
            latents = nested
    elif scales is not None:
        latents = [latent.to(device=device) for latent in latents]
    else:
        latents = latents.to(device)

    # Scale the initial noise by the standard deviation required by the scheduler.
    init_sigma = self.scheduler.init_noise_sigma
    if scales is not None:
        return [latent * init_sigma for latent in latents]
    return latents * init_sigma
|
|
|
|
|
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed (used to subsequently enrich timestep embeddings).
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    num_freqs = embedding_dim // 2
    # Geometric frequency ladder from 1 down to 1/10000.
    log_step = torch.log(torch.tensor(10000.0)) / (num_freqs - 1)
    freqs = torch.exp(torch.arange(num_freqs, dtype=dtype) * -log_step)
    angles = scaled_w.to(dtype)[:, None] * freqs[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd target width: append one zero column on the right.
        embedding = torch.nn.functional.pad(embedding, (0, 1))
    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
|
|
|
@property
def guidance_scale(self):
    """Guidance scale stored on the pipeline by `__call__` (`self._guidance_scale`)."""
    return self._guidance_scale
|
|
|
@property
def guidance_rescale(self):
    """Guidance rescale factor stored on the pipeline by `__call__` (`self._guidance_rescale`)."""
    return self._guidance_rescale
|
|
|
@property
def clip_skip(self):
    """`clip_skip` value stored on the pipeline by `__call__` (`self._clip_skip`)."""
    return self._clip_skip
|
|
|
|
|
|
|
|
|
@property
def do_classifier_free_guidance(self):
    """
    Whether classifier-free guidance is active: the stored guidance scale is greater than 1 and the UNet config
    has no `time_cond_proj_dim`.
    """
    scale_above_one = self._guidance_scale > 1
    if not scale_above_one:
        # Short-circuit exactly like `a and b`: the UNet config is not consulted in this case.
        return scale_above_one
    return self.unet.config.time_cond_proj_dim is None
|
|
|
@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs stored on the pipeline by `__call__` (`self._cross_attention_kwargs`)."""
    return self._cross_attention_kwargs
|
|
|
@property
def num_timesteps(self):
    """Number of timesteps recorded by `__call__` before entering the denoising loop (`self._num_timesteps`)."""
    return self._num_timesteps
|
|
|
@property
def interrupt(self):
    """Interrupt flag (`self._interrupt`); `__call__` resets it to `False` and checks it on every step."""
    return self._interrupt
|
|
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. When the scheduler defines
            `scales`, an iterable with one tensor per scale is expected instead.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~MatryoshkaPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~MatryoshkaPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned, otherwise a one-element `tuple`
            holding the generated images. When the scheduler defines `scales`, the images are a list with one
            post-processed image per scale.
    """
    # Legacy `callback` / `callback_steps` kwargs are still accepted but deprecated.
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # Callback objects declare which tensors they want to receive.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width come from the UNet config.
    height = height or self.unet.config.sample_size
    width = width or self.unet.config.sample_size

    # 1. Validate inputs (raises on invalid combinations).
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Derive the batch size from the prompt or, failing that, from the prompt embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompt(s).
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # With classifier-free guidance the unconditional and conditional embeddings are stacked so that a
    # single UNet forward pass serves both.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])
        attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
    else:
        attention_masks = prompt_attention_mask

    # Zero out the embeddings at masked positions.
    prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare timesteps.
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # NOTE(review): the last scheduler timestep is dropped; presumably related to the UNet being queried
    # at `t - 1` below - confirm against the scheduler implementation.
    timesteps = timesteps[:-1]

    # 5. Prepare latent variables (a list of tensors when the scheduler defines `scales`).
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        self.scheduler.scales,
        latents,
    )

    # 6. Extra kwargs for `scheduler.step`.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    extra_step_kwargs["use_clipped_model_output"] = True

    # 6.1 Image embeddings for IP-Adapter.
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
        else None
    )

    # 6.2 Optional guidance-scale embedding for UNets configured with `time_cond_proj_dim`.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # Duplicate the latents when doing classifier-free guidance (unconditional + conditional halves).
            if self.do_classifier_free_guidance and isinstance(latents, list):
                latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents]
            elif self.do_classifier_free_guidance:
                latent_model_input = latents.repeat(2, 1, 1, 1)
            else:
                latent_model_input = latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred = self.unet(
                latent_model_input,
                t - 1,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                encoder_attention_mask=attention_masks,
                return_dict=False,
            )[0]

            # Apply guidance.
            if self.do_classifier_free_guidance:
                if isinstance(noise_pred, list):
                    # A dedicated index name is used so the denoising step index `i` is not overwritten,
                    # and each per-scale prediction is split with `chunk(2)` (as in the single-scale
                    # branch) so that effective batch sizes larger than 1 work as well.
                    for scale_idx, scale_pred in enumerate(noise_pred):
                        noise_pred_uncond, noise_pred_text = scale_pred.chunk(2)
                        noise_pred[scale_idx] = noise_pred_uncond + self.guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )
                else:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                # NOTE(review): for nested (list) predictions `noise_pred` is a list here and
                # `noise_pred_text` belongs to the last scale only - confirm `rescale_noise_cfg` supports that.
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace any of these tensors.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # Update the progress bar and call the legacy callback, if provided.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    image = latents

    # 8. Post-process: one image per scale for nested models, otherwise the whole batch at once.
    if self.scheduler.scales is not None:
        for i, img in enumerate(image):
            image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
    else:
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models.
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return MatryoshkaPipelineOutput(images=image)
|
|