"""Plug-and-Play (PnP) feature-injection hooks for a diffusers-style video UNet.

The functions here monkey-patch selected decoder (``up_blocks``) modules of
``model.unet`` so that, at scheduled timesteps, features computed for the
*source* part of the batch overwrite the features of the *unconditional* and
*conditional* parts.  The batch is expected to be the concatenation
``[source, uncond, cond]`` (three equally sized chunks).

Typical use:
    1. ``register_conv_injection`` / ``register_spatial_attention_pnp`` /
       ``register_temp_attention_pnp`` once, with the injection schedules.
    2. ``register_time(model, t)`` before every denoising step so the patched
       modules know the current timestep.
"""
import glob
import logging
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import yaml
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.downsampling import Downsample2D
from diffusers.models.upsampling import Upsample2D
from diffusers.utils import USE_PEFT_BACKEND
from PIL import Image
from torchvision.io import read_video, write_video

logger = logging.getLogger(__name__)

# up_blocks index -> block indices whose attention processors receive the
# current timestep in `register_time`.
_TIME_REGISTRATION_BLOCKS = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}

# up_blocks index -> block indices whose self-attention is replaced by a PnP
# processor.  We are injecting attention in blocks 4 - 11 of the decoder, so
# not in the first block of the lowest resolution.
_ATTN_INJECTION_BLOCKS = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}

# Injection is always performed at this timestep, regardless of the schedule.
# NOTE(review): presumably the initial (maximum) scheduler timestep - confirm.
_ALWAYS_INJECT_TIMESTEP = 1000


def _should_inject(injection_schedule, t) -> bool:
    """Return True when features must be injected at timestep ``t``."""
    return injection_schedule is not None and (
        t in injection_schedule or t == _ALWAYS_INJECT_TIMESTEP
    )


def _inject_source_chunk(features: torch.Tensor, chunk_size: int) -> None:
    """Overwrite, in place, the uncond and cond chunks with the source chunk.

    ``features`` is indexed along dim 0 as ``[source, uncond, cond]`` where the
    first two chunks have ``chunk_size`` rows and the last takes the remainder.
    """
    source = features[:chunk_size]
    # inject source into unconditional
    features[chunk_size : 2 * chunk_size] = source
    # inject source into conditional
    features[2 * chunk_size :] = source


def _log_injection(label: str, t) -> None:
    """Emit the PnP debug message, formatting ``t`` only if DEBUG is enabled."""
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"PnP Injecting {label} at t={t}")


def register_time(model, t):
    """Store the current timestep ``t`` on every module that may inject.

    Sets a ``t`` attribute on the patched ResNet block
    (``model.unet.up_blocks[1].resnets[1]``) and on the ``attn1`` processors of
    the first transformer block of each spatial (``attentions``) and temporal
    (``temp_attentions``) attention module listed in
    ``_TIME_REGISTRATION_BLOCKS``.

    Args:
        model: Object exposing ``model.unet.up_blocks``.
        t: Current denoising timestep; compared against the injection
            schedules inside the patched forward passes.
    """
    conv_module = model.unet.up_blocks[1].resnets[1]
    setattr(conv_module, "t", t)

    for res, blocks in _TIME_REGISTRATION_BLOCKS.items():
        up_block = model.unet.up_blocks[res]
        for block in blocks:
            for attn_modules in (up_block.attentions, up_block.temp_attentions):
                processor = attn_modules[block].transformer_blocks[0].attn1.processor
                setattr(processor, "t", t)


# PNP injection functions
# Modified from ResnetBlock2D.forward (diffusers `models/` package).
def register_conv_injection(model, injection_schedule):
    """Patch ``model.unet.up_blocks[1].resnets[1]`` to inject source features.

    The module's ``forward`` is replaced by a copy of the ResNet block forward
    pass that, right after ``conv2``, overwrites the uncond/cond chunks of the
    hidden states with the source chunk whenever the module's ``t`` attribute
    is in ``injection_schedule`` (or equals 1000).  ``register_time`` must be
    called before each forward pass so that ``t`` is defined.

    Args:
        model: Object exposing ``model.unet.up_blocks``.
        injection_schedule: Container of timesteps at which to inject, or
            ``None`` to disable injection.
    """

    def conv_forward(self):
        def resample(sampler, sampler_cls, tensor, scale):
            # Only the diffusers 2D resamplers accept the LoRA `scale` argument.
            if isinstance(sampler, sampler_cls):
                return sampler(tensor, scale=scale)
            return sampler(tensor)

        def forward(
            input_tensor: torch.FloatTensor,
            temb: torch.FloatTensor,
            scale: float = 1.0,
        ) -> torch.FloatTensor:
            hidden_states = input_tensor

            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes.
                # see huggingface/diffusers issue #984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = resample(self.upsample, Upsample2D, input_tensor, scale)
                hidden_states = resample(self.upsample, Upsample2D, hidden_states, scale)
            elif self.downsample is not None:
                input_tensor = resample(self.downsample, Downsample2D, input_tensor, scale)
                hidden_states = resample(self.downsample, Downsample2D, hidden_states, scale)

            hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

            if self.time_emb_proj is not None:
                if not self.skip_time_act:
                    temb = self.nonlinearity(temb)
                temb = (
                    self.time_emb_proj(temb, scale)[:, :, None, None]
                    if not USE_PEFT_BACKEND
                    else self.time_emb_proj(temb)[:, :, None, None]
                )

            if self.time_embedding_norm == "default":
                if temb is not None:
                    hidden_states = hidden_states + temb
                hidden_states = self.norm2(hidden_states)
            elif self.time_embedding_norm == "scale_shift":
                if temb is None:
                    raise ValueError(
                        f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                    )
                time_scale, time_shift = torch.chunk(temb, 2, dim=1)
                hidden_states = self.norm2(hidden_states)
                hidden_states = hidden_states * (1 + time_scale) + time_shift
            else:
                hidden_states = self.norm2(hidden_states)

            hidden_states = self.nonlinearity(hidden_states)

            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

            # Modified here: PnP feature injection.
            if _should_inject(self.injection_schedule, self.t):
                _log_injection("Conv", self.t)
                source_batch_size = int(hidden_states.shape[0] // 3)
                _inject_source_chunk(hidden_states, source_batch_size)

            if self.conv_shortcut is not None:
                input_tensor = (
                    self.conv_shortcut(input_tensor, scale)
                    if not USE_PEFT_BACKEND
                    else self.conv_shortcut(input_tensor)
                )

            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

            return output_tensor

        return forward

    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, "injection_schedule", injection_schedule)


# Modified from AttnProcessor2_0.__call__ (diffusers `models/` package).
class _PnPAttnProcessor(AttnProcessor2_0):
    """Scaled-dot-product attention processor with PnP query/key injection.

    Instances are expected to carry two attributes set from outside:
    ``injection_schedule`` (set by the ``register_*_attention_pnp`` functions)
    and ``t`` (set by ``register_time`` before each forward pass).
    """

    # Label used in the debug log message; overridden by subclasses.
    injection_label = "Attn"

    def __call__(
        self,
        attn,  # attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Run attention, replacing uncond/cond queries and keys when scheduled.

        When ``self.t`` is in ``self.injection_schedule`` (or equals 1000) the
        query and key rows of the unconditional and conditional chunks are
        overwritten with those of the source chunk before attention is
        computed.  Values are never injected.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # Modified here
        # batch_size is 3*chunk_size because concat[source, uncond, cond]
        chunk_size = batch_size // 3

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        # Modified here.
        if _should_inject(self.injection_schedule, self.t):
            _log_injection(self.injection_label, self.t)
            _inject_source_chunk(query, chunk_size)
            _inject_source_chunk(key, chunk_size)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # NOTE(review): the right-hand side of this statement was missing in
        # this copy of the file; restored to match upstream AttnProcessor2_0.
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class ModifiedSpaAttnProcessor(_PnPAttnProcessor):
    """PnP processor installed on spatial self-attention (``attentions``)."""

    injection_label = "Spa-Attn"


class ModifiedTmpAttnProcessor(_PnPAttnProcessor):
    """PnP processor installed on temporal self-attention (``temp_attentions``)."""

    injection_label = "Tmp-Attn"


def _install_attn_processors(model, attn_attr, processor_cls, injection_schedule):
    """Replace the ``attn1`` processor of every block in ``_ATTN_INJECTION_BLOCKS``.

    ``attn_attr`` names the attention list on each up block (``"attentions"``
    or ``"temp_attentions"``).  A fresh ``processor_cls`` instance is created
    per block so that each one carries its own ``t`` attribute.
    """
    for res, blocks in _ATTN_INJECTION_BLOCKS.items():
        attn_modules = getattr(model.unet.up_blocks[res], attn_attr)
        for block in blocks:
            module = attn_modules[block].transformer_blocks[0].attn1
            modified_processor = processor_cls()
            setattr(modified_processor, "injection_schedule", injection_schedule)
            module.processor = modified_processor


def register_spatial_attention_pnp(model, injection_schedule):
    """Install PnP processors on the spatial self-attention of the decoder.

    Args:
        model: Object exposing ``model.unet.up_blocks[res].attentions``.
        injection_schedule: Container of timesteps at which to inject, or
            ``None`` to disable injection.
    """
    _install_attn_processors(model, "attentions", ModifiedSpaAttnProcessor, injection_schedule)


def register_temp_attention_pnp(model, injection_schedule):
    """Install PnP processors on the temporal self-attention of the decoder.

    Args:
        model: Object exposing ``model.unet.up_blocks[res].temp_attentions``.
        injection_schedule: Container of timesteps at which to inject, or
            ``None`` to disable injection.
    """
    _install_attn_processors(model, "temp_attentions", ModifiedTmpAttnProcessor, injection_schedule)