from einops import rearrange
import torch
import torch.nn as nn

from diffusers.models.attention_processor import Attention
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.models.downsampling import Downsample2D


class TemporalConvBlock(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input.
    Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        spa_pad = int((spa_stride - 1) * 0.5)
        temp_pad = 0
        self.temp_pad = temp_pad

        # NOTE: conv2-conv4 and the residual connection below assume
        # in_dim == out_dim; every call site in this file satisfies this.
        if down_sample:
            # Stride-2 temporal convolution halves the number of frames.
            self.conv1 = nn.Sequential(
                nn.GroupNorm(32, in_dim),
                nn.SiLU(),
                nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2, 1, 1), padding=(0, spa_pad, spa_pad)),
            )
        elif up_sample:
            # Channels are doubled here and folded back into the frame axis in forward().
            self.conv1 = nn.Sequential(
                nn.GroupNorm(32, in_dim),
                nn.SiLU(),
                nn.Conv3d(in_dim, out_dim * 2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad)),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.GroupNorm(32, in_dim),
                nn.SiLU(),
                nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
            )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )

        # Zero out the last layer's parameters so the block is an identity mapping at initialization.
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

        self.down_sample = down_sample
        self.up_sample = up_sample

    def forward(self, hidden_states):
        identity = hidden_states

        if self.down_sample:
            # Residual branch: keep every second frame to match the strided conv.
            identity = identity[:, :, ::2]
        elif self.up_sample:
            # Residual branch: repeat every frame twice to match the upsampled output.
            hidden_states_new = torch.cat((hidden_states, hidden_states), dim=2)
            hidden_states_new[:, :, 0::2] = hidden_states
            hidden_states_new[:, :, 1::2] = hidden_states
            identity = hidden_states_new
            del hidden_states_new

        if self.down_sample or self.up_sample:
            hidden_states = self.conv1(hidden_states)
        else:
            # Replicate-pad the frame axis by one on each side, since the temporal
            # kernel is 3 and the conv itself uses no temporal padding.
            hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
            hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
            hidden_states = self.conv1(hidden_states)

        if self.up_sample:
            # Fold the doubled channels back into the frame axis, interleaving frames.
            hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)

        hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
        hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
        hidden_states = self.conv2(hidden_states)
        hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
        hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
        hidden_states = self.conv3(hidden_states)
        hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
        hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states
        return hidden_states
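
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal shape check for
# TemporalConvBlock, assuming 5D video tensors laid out as (batch, channels,
# frames, height, width) and a channel count divisible by 32, as required by
# the GroupNorm layers above. The helper name `_demo_temporal_conv_block` is
# hypothetical, added purely for illustration.
# ---------------------------------------------------------------------------
def _demo_temporal_conv_block():
    x = torch.randn(1, 64, 8, 16, 16)  # (b, c, f, h, w)
    # Plain block: frame count is preserved by the replicate padding in forward().
    assert TemporalConvBlock(64)(x).shape == (1, 64, 8, 16, 16)
    # Temporal downsample: 8 frames -> 4 frames.
    assert TemporalConvBlock(64, down_sample=True, spa_stride=3)(x).shape == (1, 64, 4, 16, 16)
    # Temporal upsample: 8 frames -> 16 frames.
    assert TemporalConvBlock(64, up_sample=True, spa_stride=3)(x).shape == (1, 64, 16, 16, 16)
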
class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        add_temp_downsample=False,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        self.add_temp_downsample = add_temp_downsample
        if add_temp_downsample:
            self.temp_convs_down = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                down_sample=True,
                spa_stride=3,
            )

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

    def _set_partial_grad(self):
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)
        if self.downsamplers:
            for down_layer in self.downsamplers:
                down_layer.requires_grad_(True)

    def forward(self, hidden_states):
        bz = hidden_states.shape[0]

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # Spatial ResNet block runs per frame; temporal conv runs across frames.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        if self.add_temp_downsample:
            hidden_states = self.temp_convs_down(hidden_states)

        if self.downsamplers is not None:
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        return hidden_states
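
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module), assuming the same
# (batch, channels, frames, height, width) layout. With add_downsample the
# spatial resolution halves; with add_temp_downsample the frame count halves
# as well. `_demo_down_encoder_block` is a hypothetical helper for illustration.
# ---------------------------------------------------------------------------
def _demo_down_encoder_block():
    block = DownEncoderBlock3D(64, 128, add_downsample=True, add_temp_downsample=True)
    x = torch.randn(1, 64, 8, 32, 32)
    # 64 -> 128 channels, 8 -> 4 frames, 32x32 -> 16x16 spatially.
    assert block(x).shape == (1, 128, 4, 16, 16)
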
class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        add_temp_upsample=False,
        temb_channels=None,
    ):
        super().__init__()
        self.add_upsample = add_upsample

        resnets = []
        temp_convs = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        self.add_temp_upsample = add_temp_upsample
        if add_temp_upsample:
            self.temp_conv_up = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                up_sample=True,
                spa_stride=3,
            )

        if self.add_upsample:
            # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)])
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def _set_partial_grad(self):
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)
        if self.add_upsample:
            self.upsamplers.requires_grad_(True)

    def forward(self, hidden_states):
        bz = hidden_states.shape[0]

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        if self.add_temp_upsample:
            hidden_states = self.temp_conv_up(hidden_states)

        if self.upsamplers is not None:
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        return hidden_states
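
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): the mirror image of the
# encoder demo above. With add_upsample the spatial resolution doubles; with
# add_temp_upsample the frame count doubles. `_demo_up_decoder_block` is a
# hypothetical helper for illustration.
# ---------------------------------------------------------------------------
def _demo_up_decoder_block():
    block = UpDecoderBlock3D(128, 64, add_upsample=True, add_temp_upsample=True)
    x = torch.randn(1, 128, 4, 16, 16)
    # 128 -> 64 channels, 4 -> 8 frames, 16x16 -> 32x32 spatially.
    assert block(x).shape == (1, 64, 8, 32, 32)
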
class UNetMidBlock3DConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # There is always at least one resnet.
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvBlock(
                in_channels,
                in_channels,
                dropout=0.1,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)

    def _set_partial_grad(self):
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)

    def forward(self, hidden_states):
        bz = hidden_states.shape[0]

        hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        hidden_states = self.resnets[0](hidden_states, temb=None)
        hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        hidden_states = self.temp_convs[0](hidden_states)
        hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')

        for attn, resnet, temp_conv in zip(
            self.attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            if attn is not None:  # entries are None when add_attention=False
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)
            # Flatten frames back into the batch so the next attention/resnet
            # pair receives 4D input (also fixes the shape mismatch the original
            # code hit when num_layers > 1).
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')

        hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        return hidden_states
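
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): the mid block preserves all
# dimensions; attention runs per frame over the flattened spatial positions.
# `_demo_mid_block` and the __main__ runner below are hypothetical additions
# so the shape checks in this file can be executed directly.
# ---------------------------------------------------------------------------
def _demo_mid_block():
    block = UNetMidBlock3DConv(in_channels=64, temb_channels=None, attention_head_dim=32)
    x = torch.randn(1, 64, 4, 16, 16)
    assert block(x).shape == (1, 64, 4, 16, 16)


if __name__ == "__main__":
    with torch.no_grad():
        _demo_temporal_conv_block()
        _demo_down_encoder_block()
        _demo_up_decoder_block()
        _demo_mid_block()
    print("all shape checks passed")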