import torch
import torch.nn as nn

from typing import Union, Tuple, Optional, Dict, Any

from diffusers.utils import is_torch_version
from diffusers.models.resnet import (
    Downsample2D,
    SpatioTemporalResBlock,
    Upsample2D,
)
from diffusers.models.unet_3d_blocks import (
    DownBlockSpatioTemporal,
    UpBlockSpatioTemporal,
)

from cameractrl.models.transformer_temporal import TransformerSpatioTemporalModelPoseCond


def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    num_attention_heads: int,
    cross_attention_dim: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    **kwargs,
) -> Union[
    "DownBlockSpatioTemporal",
    "CrossAttnDownBlockSpatioTemporalPoseCond",
]:
    if down_block_type == "DownBlockSpatioTemporal":
        # added for SDV
        return DownBlockSpatioTemporal(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    elif down_block_type == "CrossAttnDownBlockSpatioTemporalPoseCond":
        # added for SDV
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporalPoseCond")
        return CrossAttnDownBlockSpatioTemporalPoseCond(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            add_downsample=add_downsample,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
        )

    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    num_attention_heads: int,
    resolution_idx: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    **kwargs,
) -> Union[
    "UpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporalPoseCond",
]:
    if up_block_type == "UpBlockSpatioTemporal":
        # added for SDV
        return UpBlockSpatioTemporal(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            add_upsample=add_upsample,
        )
    elif up_block_type == "CrossAttnUpBlockSpatioTemporalPoseCond":
        # added for SDV
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporalPoseCond")
        return CrossAttnUpBlockSpatioTemporalPoseCond(
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            add_upsample=add_upsample,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            resolution_idx=resolution_idx,
        )

    raise ValueError(f"{up_block_type} does not exist.")
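
# --- Illustrative usage (editor's sketch, not part of the original module) --
# A minimal example of how the two factory functions above are typically
# called when assembling an SVD-style UNet. The channel widths, head counts,
# and cross-attention dimension below are assumptions for illustration, not
# values read from any real checkpoint config. The function is never invoked
# at import time.
def _example_factory_usage():
    down = get_down_block(
        "CrossAttnDownBlockSpatioTemporalPoseCond",
        num_layers=2,
        in_channels=320,
        out_channels=640,
        temb_channels=1280,
        add_downsample=True,
        num_attention_heads=10,
        cross_attention_dim=1024,
    )
    up = get_up_block(
        "CrossAttnUpBlockSpatioTemporalPoseCond",
        num_layers=3,
        in_channels=320,
        out_channels=640,
        prev_output_channel=1280,
        temb_channels=1280,
        add_upsample=True,
        num_attention_heads=10,
        cross_attention_dim=1024,
        resolution_idx=0,
    )
    return down, up
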
class CrossAttnDownBlockSpatioTemporalPoseCond(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_downsample: bool = True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=1e-6,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=1,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,  # [bs * frame, c, h, w]
        temb: Optional[torch.FloatTensor] = None,  # [bs * frame, c]
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # [bs * frame, 1, c]
        image_only_indicator: Optional[torch.Tensor] = None,  # [bs, frame]
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        blocks = list(zip(self.resnets, self.attentions))
        for resnet, attn in blocks:
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
                # Pass pose_feature here as well so pose conditioning is not
                # silently dropped when gradient checkpointing is enabled.
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )  # [bs * frame, c, h, w]
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
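
# --- Illustrative forward pass (editor's sketch, not part of the original
# module). Tensor shapes follow the comments on forward() above; the concrete
# sizes and the pose-feature channel width are assumptions for illustration
# only (in CameraCtrl they come from the UNet config and the pose encoder).
def _example_down_block_forward():
    bs, frames, h, w = 1, 8, 32, 32
    block = CrossAttnDownBlockSpatioTemporalPoseCond(
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        num_layers=2,
        num_attention_heads=8,
        cross_attention_dim=1024,
    )
    hidden_states = torch.randn(bs * frames, 320, h, w)
    temb = torch.randn(bs * frames, 1280)
    encoder_hidden_states = torch.randn(bs * frames, 1, 1024)
    image_only_indicator = torch.zeros(bs, frames)
    pose_feature = torch.randn(bs, 320, frames, h, w)  # channel width is an assumption
    hidden_states, skip_states = block(
        hidden_states,
        temb=temb,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
        pose_feature=pose_feature,
    )
    # skip_states holds one entry per (resnet, attn) pair plus the downsampled
    # output; hidden_states is the downsampled tensor at half resolution.
    return hidden_states, skip_states
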
class UNetMidBlockSpatioTemporalPoseCond(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # there is always at least one resnet
        resnets = [
            SpatioTemporalResBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=1e-5,
            )
        ]
        attentions = []

        for i in range(num_layers):
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=1e-5,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> torch.FloatTensor:
        hidden_states = self.resnets[0](
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # Pass pose_feature here as well so pose conditioning matches
                # the non-checkpointed branch below.
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

        return hidden_states
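
# Layer layout of the mid block above, for reference: with num_layers = n the
# modules run as resnet -> (attn -> resnet) * n, i.e. n + 1 resnets wrap n
# pose-conditioned transformers, all at a constant channel width of
# `in_channels` (no down- or upsampling happens in the mid block).
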
class CrossAttnUpBlockSpatioTemporalPoseCond(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        add_upsample: bool = True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            attentions.append(
                TransformerSpatioTemporalModelPoseCond(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        pose_feature: Optional[torch.Tensor] = None,  # [bs, c, frame, h, w]
    ) -> torch.FloatTensor:
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
                # Pass pose_feature here as well so pose conditioning matches
                # the non-checkpointed branch below.
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    pose_feature=pose_feature,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
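
# --- Smoke test (editor's sketch, not part of the original module) ----------
# Wires one down block, the mid block, and one up block together on dummy
# tensors to illustrate the expected shapes. All sizes here, including the
# pose-feature channel width, are illustrative assumptions; real values come
# from the UNet config and the pose encoder.
if __name__ == "__main__":
    bs, frames, c, h, w = 1, 2, 320, 16, 16

    down = CrossAttnDownBlockSpatioTemporalPoseCond(
        in_channels=c, out_channels=c, temb_channels=1280,
        num_layers=1, num_attention_heads=8, cross_attention_dim=1024,
    )
    mid = UNetMidBlockSpatioTemporalPoseCond(
        in_channels=c, temb_channels=1280,
        num_layers=1, num_attention_heads=8, cross_attention_dim=1024,
    )
    up = CrossAttnUpBlockSpatioTemporalPoseCond(
        in_channels=c, out_channels=c, prev_output_channel=c,
        temb_channels=1280, num_layers=1, num_attention_heads=8,
        cross_attention_dim=1024,
    )

    x = torch.randn(bs * frames, c, h, w)
    temb = torch.randn(bs * frames, 1280)
    ehs = torch.randn(bs * frames, 1, 1024)                # encoder hidden states
    ioi = torch.zeros(bs, frames)                          # image_only_indicator
    pose_hi = torch.randn(bs, c, frames, h, w)             # pose feature @ full res (assumed width)
    pose_lo = torch.randn(bs, c, frames, h // 2, w // 2)   # pose feature @ half res

    x, skips = down(x, temb=temb, encoder_hidden_states=ehs,
                    image_only_indicator=ioi, pose_feature=pose_hi)
    x = mid(x, temb, encoder_hidden_states=ehs,
            image_only_indicator=ioi, pose_feature=pose_lo)
    # Feed the up block the single downsampled skip state; with num_layers=1
    # its forward pops exactly one entry from the tuple.
    x = up(x, res_hidden_states_tuple=(skips[-1],), temb=temb,
           encoder_hidden_states=ehs, image_only_indicator=ioi,
           pose_feature=pose_lo)
    print(x.shape)  # expected: [bs * frames, c, h, w] after the final upsampler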