from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union, Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import is_torch_version
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_3d_blocks import (
    UNetMidBlockSpatioTemporal,
    get_down_block as gdb,
    get_up_block as gub,
)
from diffusers.models.resnet import (
    Downsample2D,
    SpatioTemporalResBlock,
    Upsample2D,
)
from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel
from diffusers.models.attention_processor import Attention
from diffusers.utils import deprecate
from diffusers.utils.import_utils import is_xformers_available

from network_utils import DragEmbedding, get_2d_sincos_pos_embed

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_xformers_available():
    import xformers
    import xformers.ops


class AllToFirstXFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase)
            to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the
            best operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # This processor only supports self-attention.
        assert encoder_hidden_states is None

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        # Keys and values always come from the first frame of each clip
        # (the frame dimension is hard-coded to 14 here).
        key = attn.to_k(hidden_states.view(-1, 14, *hidden_states.shape[1:])[:, 0])[:, None].expand(-1, 14, -1, -1).flatten(0, 1)
        value = attn.to_v(hidden_states.view(-1, 14, *hidden_states.shape[1:])[:, 0])[:, None].expand(-1, 14, -1, -1).flatten(0, 1)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
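
# NOTE (usage sketch, not invoked anywhere in this file): this processor makes every
# frame of a 14-frame clip attend to the keys/values of the clip's first frame. It
# asserts that `encoder_hidden_states` is None, so it should only be assigned to
# self-attention layers, e.g. (hypothetical mapping, assuming diffusers' "attn1"
# naming for self-attention):
#
#     procs = {
#         name: AllToFirstXFormersAttnProcessor() if "attn1" in name else AttnProcessor()
#         for name in unet.attn_processors
#     }
#     unet.set_attn_processor(procs)
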

class CrossAttnDownBlockSpatioTemporalWithFlow(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        flow_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_downsample: bool = True,
        num_frames: int = 14,
        pos_embed_dim: int = 64,
        drag_token_cross_attn: bool = True,
        use_modulate: bool = True,
        drag_embedder_out_channels=(256, 320, 320),
        num_max_drags: int = 5,
    ):
        super().__init__()
        resnets = []
        attentions = []
        flow_convs = []
        if drag_token_cross_attn:
            drag_token_mlps = []

        self.num_max_drags = num_max_drags
        self.num_frames = num_frames
        self.pos_embed_dim = pos_embed_dim
        self.drag_token_cross_attn = drag_token_cross_attn
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        self.use_modulate = use_modulate

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=1e-6,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )
            flow_convs.append(
                DragEmbedding(
                    conditioning_channels=flow_channels,
                    conditioning_embedding_channels=out_channels * 2 if use_modulate else out_channels,
                    block_out_channels=drag_embedder_out_channels,
                )
            )
            if drag_token_cross_attn:
                drag_token_mlps.append(
                    nn.Sequential(
                        nn.Linear(pos_embed_dim * 2 + out_channels * 2, cross_attention_dim),
                        nn.SiLU(),
                        nn.Linear(cross_attention_dim, cross_attention_dim),
                    )
                )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.flow_convs = nn.ModuleList(flow_convs)
        if drag_token_cross_attn:
            self.drag_token_mlps = nn.ModuleList(drag_token_mlps)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=1,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.pos_embedding = {
            res: torch.tensor(get_2d_sincos_pos_embed(self.pos_embed_dim, res)) for res in [32, 16, 8, 4, 2]
        }
        self.pos_embedding_prepared = False

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        flow: Optional[torch.Tensor] = None,
        drag_original: Optional[torch.Tensor] = None,  # (batch_frame, num_points, 4)
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()
        batch_frame = hidden_states.shape[0]
        if self.drag_token_cross_attn:
            encoder_hidden_states_ori = encoder_hidden_states

        if not self.pos_embedding_prepared:
            for res in self.pos_embedding:
                self.pos_embedding[res] = self.pos_embedding[res].to(hidden_states)
            self.pos_embedding_prepared = True

        blocks = list(zip(self.resnets, self.attentions, self.flow_convs))
        for bid, (resnet, attn, flow_conv) in enumerate(blocks):
            if self.training and self.gradient_checkpointing:  # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )

                if flow is not None:
                    # flow: (batch_frame, flow_channels, h, w), i.e. num_max_drags drags of 6 channels each
                    drags = flow.view(-1, self.num_frames, *flow.shape[1:])
                    drags = drags.chunk(self.num_max_drags, dim=2)  # num_max_drags x (batch, frame, c, h, w)
                    drags = torch.stack(drags, dim=0)  # (num_max_drags, batch, frame, c, h, w)
                    invalid_flag = torch.all(drags == -1, dim=(2, 3, 4, 5))
                    if self.use_modulate:
                        scale, shift = flow_conv(flow).chunk(2, dim=1)
                    else:
                        scale = 0
                        shift = flow_conv(flow)
                    hidden_states = hidden_states * (1 + scale) + shift

                    if self.drag_token_cross_attn:
                        drag_token_mlp = self.drag_token_mlps[bid]
                        pos_embed = self.pos_embedding[scale.shape[-1]]
                        pos_embed = pos_embed.reshape(1, scale.shape[-1], scale.shape[-1], -1).permute(0, 3, 1, 2)
                        grid = (drag_original[..., :2] * 2 - 1)[:, None]
                        grid_end = (drag_original[..., 2:] * 2 - 1)[:, None]
                        drags_pos_start = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drags_pos_end = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features = F.grid_sample(
                            hidden_states.detach(), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features_end = F.grid_sample(
                            hidden_states.detach(), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drag_token_in = torch.cat(
                            [features, features_end, drags_pos_start, drags_pos_end], dim=1
                        ).permute(0, 2, 1)
                        drag_token_out = drag_token_mlp(drag_token_in)

                        # Mask the invalid drags
                        drag_token_out = drag_token_out.view(
                            batch_frame // self.num_frames, self.num_frames, self.num_max_drags, -1
                        )
                        drag_token_out = drag_token_out.permute(2, 0, 1, 3)
                        drag_token_out = drag_token_out.masked_fill(
                            invalid_flag[..., None, None].expand_as(drag_token_out), 0
                        )
                        drag_token_out = drag_token_out.permute(1, 2, 0, 3).flatten(0, 1)
                        encoder_hidden_states = torch.cat([encoder_hidden_states_ori, drag_token_out], dim=1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

                if flow is not None:
                    # flow: (batch_frame, flow_channels, h, w), i.e. num_max_drags drags of 6 channels each
                    drags = flow.view(-1, self.num_frames, *flow.shape[1:])
                    drags = drags.chunk(self.num_max_drags, dim=2)  # num_max_drags x (batch, frame, c, h, w)
                    drags = torch.stack(drags, dim=0)  # (num_max_drags, batch, frame, c, h, w)
                    invalid_flag = torch.all(drags == -1, dim=(2, 3, 4, 5))
                    if self.use_modulate:
                        scale, shift = flow_conv(flow).chunk(2, dim=1)
                    else:
                        scale = 0
                        shift = flow_conv(flow)
                    hidden_states = hidden_states * (1 + scale) + shift

                    if self.drag_token_cross_attn:
                        drag_token_mlp = self.drag_token_mlps[bid]
                        pos_embed = self.pos_embedding[scale.shape[-1]]
                        pos_embed = pos_embed.reshape(1, scale.shape[-1], scale.shape[-1], -1).permute(0, 3, 1, 2)
                        grid = (drag_original[..., :2] * 2 - 1)[:, None]
                        grid_end = (drag_original[..., 2:] * 2 - 1)[:, None]
                        drags_pos_start = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drags_pos_end = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features = F.grid_sample(
                            hidden_states.detach(), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features_end = F.grid_sample(
                            hidden_states.detach(), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drag_token_in = torch.cat(
                            [features, features_end, drags_pos_start, drags_pos_end], dim=1
                        ).permute(0, 2, 1)
                        drag_token_out = drag_token_mlp(drag_token_in)

                        # Mask the invalid drags
                        drag_token_out = drag_token_out.view(
                            batch_frame // self.num_frames, self.num_frames, self.num_max_drags, -1
                        )
                        drag_token_out = drag_token_out.permute(2, 0, 1, 3)
                        drag_token_out = drag_token_out.masked_fill(
                            invalid_flag[..., None, None].expand_as(drag_token_out), 0
                        )
                        drag_token_out = drag_token_out.permute(1, 2, 0, 3).flatten(0, 1)
                        encoder_hidden_states = torch.cat([encoder_hidden_states_ori, drag_token_out], dim=1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlockSpatioTemporalWithFlow(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        flow_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_upsample: bool = True,
        num_frames: int = 14,
        pos_embed_dim: int = 64,
        drag_token_cross_attn: bool = True,
        use_modulate: bool = True,
        drag_embedder_out_channels=(256, 320, 320),
        num_max_drags: int = 5,
    ):
        super().__init__()
        resnets = []
        attentions = []
        flow_convs = []
        if drag_token_cross_attn:
            drag_token_mlps = []

        self.num_max_drags = num_max_drags
        self.drag_token_cross_attn = drag_token_cross_attn
        self.num_frames = num_frames
        self.pos_embed_dim = pos_embed_dim
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        self.use_modulate = use_modulate
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )
            flow_convs.append(
                DragEmbedding(
                    conditioning_channels=flow_channels,
                    conditioning_embedding_channels=out_channels * 2 if use_modulate else out_channels,
                    block_out_channels=drag_embedder_out_channels,
                )
            )
            if drag_token_cross_attn:
                drag_token_mlps.append(
                    nn.Sequential(
                        nn.Linear(pos_embed_dim * 2 + out_channels * 2, cross_attention_dim),
                        nn.SiLU(),
                        nn.Linear(cross_attention_dim, cross_attention_dim),
                    )
                )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.flow_convs = nn.ModuleList(flow_convs)
        if drag_token_cross_attn:
            self.drag_token_mlps = nn.ModuleList(drag_token_mlps)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.pos_embedding = {
            res: torch.tensor(get_2d_sincos_pos_embed(pos_embed_dim, res)) for res in [32, 16, 8, 4, 2]
        }
        self.pos_embedding_prepared = False

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        flow: Optional[torch.Tensor] = None,
        drag_original: Optional[torch.Tensor] = None,  # (batch_frame, num_points, 4)
    ) -> torch.FloatTensor:
        batch_frame = hidden_states.shape[0]
        if self.drag_token_cross_attn:
            encoder_hidden_states_ori = encoder_hidden_states

        if not self.pos_embedding_prepared:
            for res in self.pos_embedding:
                self.pos_embedding[res] = self.pos_embedding[res].to(hidden_states)
            self.pos_embedding_prepared = True

        for bid, (resnet, attn, flow_conv) in enumerate(zip(self.resnets, self.attentions, self.flow_convs)):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:  # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )

                if flow is not None:
                    # flow: (batch_frame, flow_channels, h, w), i.e. num_max_drags drags of 6 channels each
                    drags = flow.view(-1, self.num_frames, *flow.shape[1:])
                    drags = drags.chunk(self.num_max_drags, dim=2)  # num_max_drags x (batch, frame, c, h, w)
                    drags = torch.stack(drags, dim=0)  # (num_max_drags, batch, frame, c, h, w)
                    invalid_flag = torch.all(drags == -1, dim=(2, 3, 4, 5))
                    if self.use_modulate:
                        scale, shift = flow_conv(flow).chunk(2, dim=1)
                    else:
                        scale = 0
                        shift = flow_conv(flow)
                    hidden_states = hidden_states * (1 + scale) + shift

                    if self.drag_token_cross_attn:
                        drag_token_mlp = self.drag_token_mlps[bid]
                        pos_embed = self.pos_embedding[scale.shape[-1]]
                        pos_embed = pos_embed.reshape(1, scale.shape[-1], scale.shape[-1], -1).permute(0, 3, 1, 2)
                        grid = (drag_original[..., :2] * 2 - 1)[:, None]
                        grid_end = (drag_original[..., 2:] * 2 - 1)[:, None]
                        drags_pos_start = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drags_pos_end = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features = F.grid_sample(
                            hidden_states.detach(), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features_end = F.grid_sample(
                            hidden_states.detach(), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drag_token_in = torch.cat(
                            [features, features_end, drags_pos_start, drags_pos_end], dim=1
                        ).permute(0, 2, 1)
                        drag_token_out = drag_token_mlp(drag_token_in)

                        # Mask the invalid drags
                        drag_token_out = drag_token_out.view(
                            batch_frame // self.num_frames, self.num_frames, self.num_max_drags, -1
                        )
                        drag_token_out = drag_token_out.permute(2, 0, 1, 3)
                        drag_token_out = drag_token_out.masked_fill(
                            invalid_flag[..., None, None].expand_as(drag_token_out), 0
                        )
                        drag_token_out = drag_token_out.permute(1, 2, 0, 3).flatten(0, 1)
                        encoder_hidden_states = torch.cat([encoder_hidden_states_ori, drag_token_out], dim=1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

                if flow is not None:
                    # flow: (batch_frame, flow_channels, h, w), i.e. num_max_drags drags of 6 channels each
                    drags = flow.view(-1, self.num_frames, *flow.shape[1:])
                    drags = drags.chunk(self.num_max_drags, dim=2)  # num_max_drags x (batch, frame, c, h, w)
                    drags = torch.stack(drags, dim=0)  # (num_max_drags, batch, frame, c, h, w)
                    invalid_flag = torch.all(drags == -1, dim=(2, 3, 4, 5))
                    if self.use_modulate:
                        scale, shift = flow_conv(flow).chunk(2, dim=1)
                    else:
                        scale = 0
                        shift = flow_conv(flow)
                    hidden_states = hidden_states * (1 + scale) + shift

                    if self.drag_token_cross_attn:
                        drag_token_mlp = self.drag_token_mlps[bid]
                        pos_embed = self.pos_embedding[scale.shape[-1]]
                        pos_embed = pos_embed.reshape(1, scale.shape[-1], scale.shape[-1], -1).permute(0, 3, 1, 2)
                        grid = (drag_original[..., :2] * 2 - 1)[:, None]
                        grid_end = (drag_original[..., 2:] * 2 - 1)[:, None]
                        drags_pos_start = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drags_pos_end = F.grid_sample(
                            pos_embed.repeat(batch_frame, 1, 1, 1), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features = F.grid_sample(
                            hidden_states.detach(), grid,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        features_end = F.grid_sample(
                            hidden_states.detach(), grid_end,
                            padding_mode="border", mode="bilinear", align_corners=False,
                        ).squeeze(dim=2)
                        drag_token_in = torch.cat(
                            [features, features_end, drags_pos_start, drags_pos_end], dim=1
                        ).permute(0, 2, 1)
                        drag_token_out = drag_token_mlp(drag_token_in)

                        # Mask the invalid drags
                        drag_token_out = drag_token_out.view(
                            batch_frame // self.num_frames, self.num_frames, self.num_max_drags, -1
                        )
                        drag_token_out = drag_token_out.permute(2, 0, 1, 3)
                        drag_token_out = drag_token_out.masked_fill(
                            invalid_flag[..., None, None].expand_as(drag_token_out), 0
                        )
                        drag_token_out = drag_token_out.permute(1, 2, 0, 3).flatten(0, 1)
                        encoder_hidden_states = torch.cat([encoder_hidden_states_ori, drag_token_out], dim=1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


def get_down_block(
    with_concatenated_flow: bool = False,
    *args,
    **kwargs,
):
    NEEDED_KEYS = [
        "in_channels",
        "out_channels",
        "temb_channels",
        "flow_channels",
        "num_layers",
        "transformer_layers_per_block",
        "num_attention_heads",
        "cross_attention_dim",
        "add_downsample",
        "pos_embed_dim",
        "use_modulate",
        "drag_token_cross_attn",
        "drag_embedder_out_channels",
        "num_max_drags",
    ]

    if not with_concatenated_flow or args[0] == "DownBlockSpatioTemporal":
        kwargs.pop("flow_channels", None)
        kwargs.pop("pos_embed_dim", None)
        kwargs.pop("use_modulate", None)
        kwargs.pop("drag_token_cross_attn", None)
        kwargs.pop("drag_embedder_out_channels", None)
        kwargs.pop("num_max_drags", None)
        return gdb(*args, **kwargs)
    elif args[0] == "CrossAttnDownBlockSpatioTemporal":
        for key in list(kwargs.keys()):
            if key not in NEEDED_KEYS:
                kwargs.pop(key, None)
        return CrossAttnDownBlockSpatioTemporalWithFlow(*args[1:], **kwargs)
    else:
        raise ValueError(f"Unknown block type {args[0]}")


def get_up_block(
    with_concatenated_flow: bool = False,
    *args,
    **kwargs,
):
    NEEDED_KEYS = [
        "in_channels",
        "out_channels",
        "prev_output_channel",
        "temb_channels",
        "flow_channels",
        "resolution_idx",
        "num_layers",
        "transformer_layers_per_block",
        "resnet_eps",
        "num_attention_heads",
        "cross_attention_dim",
        "add_upsample",
        "pos_embed_dim",
        "use_modulate",
        "drag_token_cross_attn",
        "drag_embedder_out_channels",
        "num_max_drags",
    ]

    if not with_concatenated_flow or args[0] == "UpBlockSpatioTemporal":
        kwargs.pop("flow_channels", None)
        kwargs.pop("pos_embed_dim", None)
        kwargs.pop("use_modulate", None)
        kwargs.pop("drag_token_cross_attn", None)
        kwargs.pop("drag_embedder_out_channels", None)
        kwargs.pop("num_max_drags", None)
        return gub(*args, **kwargs)
    elif args[0] == "CrossAttnUpBlockSpatioTemporal":
        for key in list(kwargs.keys()):
            if key not in NEEDED_KEYS:
                kwargs.pop(key, None)
        return CrossAttnUpBlockSpatioTemporalWithFlow(*args[1:], **kwargs)
    else:
        raise ValueError(f"Unknown block type {args[0]}")


@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class UNetDragSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state, and a timestep,
    and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
            [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
            The number of attention heads.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        up_block_types: Tuple[str] = (
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
        num_frames: int = 25,
        num_drags: int = 10,
        cond_dropout_prob: float = 0.1,
        pos_embed_dim: int = 64,
        drag_token_cross_attn: bool = True,
        use_modulate: bool = True,
        drag_embedder_out_channels=(256, 320, 320),
        cross_attn_with_ref: bool = True,
        double_batch: bool = False,
    ):
        super().__init__()

        self.sample_size = sample_size
        self.cond_dropout_prob = cond_dropout_prob
        self.drag_token_cross_attn = drag_token_cross_attn
        self.pos_embed_dim = pos_embed_dim
        self.use_modulate = use_modulate
        self.cross_attn_with_ref = cross_attn_with_ref
        self.double_batch = double_batch

        flow_channels = 6 * num_drags

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                True,
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                resnet_act_fn="silu",
                flow_channels=flow_channels,
                pos_embed_dim=pos_embed_dim,
                use_modulate=use_modulate,
                drag_token_cross_attn=drag_token_cross_attn,
                drag_embedder_out_channels=drag_embedder_out_channels,
                num_max_drags=num_drags,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockSpatioTemporal(
            block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                True,
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=1e-5,
                resolution_idx=i,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                resnet_act_fn="silu",
                flow_channels=flow_channels,
                pos_embed_dim=pos_embed_dim,
                use_modulate=use_modulate,
                drag_token_cross_attn=drag_token_cross_attn,
                drag_embedder_out_channels=drag_embedder_out_channels,
                num_max_drags=num_drags,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
        self.conv_act = nn.SiLU()

        self.conv_out = nn.Conv2d(
            block_out_channels[0],
            out_channels,
            kernel_size=3,
            padding=1,
        )

        self.num_drags = num_drags
        self.pos_embedding = {
            res: torch.tensor(get_2d_sincos_pos_embed(self.pos_embed_dim, res)) for res in [32, 16, 8, 4, 2]
        }
        self.pos_embedding_prepared = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)
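
    # `_convert_drag_to_concatting_image` rasterizes each drag into a sparse 6-channel
    # image per point at the given resolution: channels 0-1 hold the sub-pixel offset of
    # the drag's start point (written at the start pixel), channels 2-3 the per-frame
    # target point, and channels 4-5 the final-frame target point; every other pixel is
    # filled with -1. All-zero drags are treated as padding and are not rasterized.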
    def _convert_drag_to_concatting_image(self, drags: torch.Tensor, current_resolution: int) -> torch.Tensor:
        batch_size, num_frames, num_points, _ = drags.shape
        num_channels = 6
        concatting_image = -torch.ones(
            batch_size, num_frames, num_channels * num_points, current_resolution, current_resolution
        ).to(drags)
        not_all_zeros = drags.any(dim=-1).repeat_interleave(num_channels, dim=-1)[..., None, None]
        y_grid, x_grid = torch.meshgrid(
            torch.arange(current_resolution), torch.arange(current_resolution), indexing="ij"
        )
        y_grid = y_grid.to(drags)[None, None, None]  # (1, 1, 1, res, res)
        x_grid = x_grid.to(drags)[None, None, None]  # (1, 1, 1, res, res)

        x0 = (drags[..., 0] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        x_src = (drags[..., 0] * current_resolution - x0)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x0 = x0[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x0 = torch.stack([
            x0,
            x0,
            torch.zeros_like(x0) - 1,
            torch.zeros_like(x0) - 1,
            torch.zeros_like(x0) - 1,
            torch.zeros_like(x0) - 1,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        y0 = (drags[..., 1] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        y_src = (drags[..., 1] * current_resolution - y0)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y0 = y0[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y0 = torch.stack([
            y0,
            y0,
            torch.zeros_like(y0) - 1,
            torch.zeros_like(y0) - 1,
            torch.zeros_like(y0) - 1,
            torch.zeros_like(y0) - 1,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        x1 = (drags[..., 2] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        x_tgt = (drags[..., 2] * current_resolution - x1)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x1 = x1[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x1 = torch.stack([
            torch.zeros_like(x1) - 1,
            torch.zeros_like(x1) - 1,
            x1,
            x1,
            torch.zeros_like(x1) - 1,
            torch.zeros_like(x1) - 1,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        y1 = (drags[..., 3] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        y_tgt = (drags[..., 3] * current_resolution - y1)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y1 = y1[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y1 = torch.stack([
            torch.zeros_like(y1) - 1,
            torch.zeros_like(y1) - 1,
            y1,
            y1,
            torch.zeros_like(y1) - 1,
            torch.zeros_like(y1) - 1,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        drags_final = drags[:, -1:, :, :].expand_as(drags)
        x_final = (drags_final[..., 2] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        x_final_tgt = (drags_final[..., 2] * current_resolution - x_final)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x_final = x_final[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        x_final = torch.stack([
            torch.zeros_like(x_final) - 1,
            torch.zeros_like(x_final) - 1,
            torch.zeros_like(x_final) - 1,
            torch.zeros_like(x_final) - 1,
            x_final,
            x_final,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        y_final = (drags_final[..., 3] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
        y_final_tgt = (drags_final[..., 3] * current_resolution - y_final)[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y_final = y_final[..., None, None]  # (batch, num_frames, num_points, 1, 1)
        y_final = torch.stack([
            torch.zeros_like(y_final) - 1,
            torch.zeros_like(y_final) - 1,
            torch.zeros_like(y_final) - 1,
            torch.zeros_like(y_final) - 1,
            y_final,
            y_final,
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)

        value_image = torch.stack([
            x_src, y_src, x_tgt, y_tgt, x_final_tgt, y_final_tgt
        ], dim=3).view(batch_size, num_frames, num_channels * num_points, 1, 1)
        value_image = value_image.expand_as(concatting_image)

        start_mask = (x_grid == x0) & (y_grid == y0) & not_all_zeros
        end_mask = (x_grid == x1) & (y_grid == y1) & not_all_zeros
        final_mask = (x_grid == x_final) & (y_grid == y_final) & not_all_zeros
        concatting_image[start_mask] = value_image[start_mask]
        concatting_image[end_mask] = value_image[end_mask]
        concatting_image[final_mask] = value_image[final_mask]
        return concatting_image

    def zero_init(self):
        for block in self.down_blocks:
            if hasattr(block, "flow_convs"):
                for flow_conv in block.flow_convs:
                    try:
                        nn.init.constant_(flow_conv.conv_out.weight, 0)
                        nn.init.constant_(flow_conv.conv_out.bias, 0)
                    except AttributeError:
                        nn.init.constant_(flow_conv.weight, 0)
        for block in self.up_blocks:
            if hasattr(block, "flow_convs"):
                for flow_conv in block.flow_convs:
                    try:
                        nn.init.constant_(flow_conv.conv_out.weight, 0)
                        nn.init.constant_(flow_conv.conv_out.bias, 0)
                    except AttributeError:
                        nn.init.constant_(flow_conv.weight, 0)
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        image_latents: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        drags: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""
        The [`UNetDragSpatioTemporalConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            image_latents (`torch.FloatTensor`):
                The clean latents of the conditioning (first) frame with shape `(batch, num_channels, height, width)`.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
            added_time_ids: (`torch.FloatTensor`):
                The additional time ids with shape `(batch, num_additional_ids)`. Accepted for API compatibility with
                the SVD UNet; not used by this forward implementation.
            drags (`torch.Tensor`):
                The drags tensor with shape `(batch, num_frames, num_points, 4)`.
            force_drop_ids (`torch.Tensor`, *optional*):
                Boolean-like tensor of shape `(batch,)`; entries equal to 1 force the drag conditioning of that
                sample to be dropped (used for classifier-free guidance).

        Returns:
            `torch.FloatTensor`: The denoised sample tensor of shape
            `(batch, num_frames, out_channels, height, width)`.
        """
        batch_size, num_frames = sample.shape[:2]

        if not self.pos_embedding_prepared:
            for res in self.pos_embedding:
                self.pos_embedding[res] = self.pos_embedding[res].to(drags)
            self.pos_embedding_prepared = True

        # 0. prepare for cfg
        drag_drop_ids = None
        if (self.training and self.cond_dropout_prob > 0) or force_drop_ids is not None:
            if force_drop_ids is None:
                drag_drop_ids = torch.rand(batch_size, device=sample.device) < self.cond_dropout_prob
            else:
                drag_drop_ids = (force_drop_ids == 1)
            drags = drags * ~drag_drop_ids[:, None, None, None]

        sample = torch.cat([sample, image_latents[:, None].repeat(1, num_frames, 1, 1, 1)], dim=2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)

        if self.cross_attn_with_ref and self.double_batch:
            sample_ref = image_latents[:, None].repeat(1, num_frames, 2, 1, 1)
            sample_ref[:, :, :4] = sample_ref[:, :, :4] * 0.18215
            sample = torch.cat([sample_ref, sample], dim=0)
            drags = torch.cat([torch.zeros_like(drags), drags], dim=0)
            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
            timesteps = torch.cat([timesteps, timesteps], dim=0)
            batch_size *= 2

        drag_encodings = {res: self._convert_drag_to_concatting_image(drags, res) for res in [32, 16, 8]}

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        # Flatten the batch and frames dimensions
        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)
        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

        # 2. pre-process
        sample = self.conv_in(sample)

        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                flow = drag_encodings[sample.shape[-1]]
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    flow=flow.flatten(0, 1),
                    drag_original=drags.flatten(0, 1),
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    image_only_indicator=image_only_indicator,
                )

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                flow = drag_encodings[sample.shape[-1]]
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    flow=flow.flatten(0, 1),
                    drag_original=drags.flatten(0, 1),
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    image_only_indicator=image_only_indicator,
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # 7. Reshape back to original shape
        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

        if self.cross_attn_with_ref and self.double_batch:
            sample = sample[batch_size // 2:]

        return sample


if __name__ == "__main__":
    puppet_master = UNetDragSpatioTemporalConditionModel(num_drags=5)
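
    # Optional shape sanity check (an illustrative sketch, not part of the original
    # script): assumes 14-frame clips and 32x32 latents, since drag encodings are only
    # rasterized for resolutions 32/16/8. Building and running the full model is slow
    # and memory-hungry on CPU; flip the flag below to run it.
    RUN_SMOKE_TEST = False
    if RUN_SMOKE_TEST:
        puppet_master.eval()
        with torch.no_grad():
            sample = torch.randn(1, 14, 4, 32, 32)           # noisy video latents
            image_latents = torch.randn(1, 4, 32, 32)        # clean first-frame latents
            encoder_hidden_states = torch.randn(1, 1, 1024)  # e.g. a CLIP image embedding
            added_time_ids = torch.zeros(1, 3)               # accepted but unused by this forward
            drags = torch.rand(1, 14, 5, 4)                  # (batch, frames, points, xyxy) in [0, 1]
            out = puppet_master(
                sample, 10, image_latents, encoder_hidden_states, added_time_ids, drags
            )
            print(out.shape)  # expected: (1, 14, 4, 32, 32)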