|
from typing import Optional, Tuple, Union |
|
from einops import rearrange |
|
|
|
import torch |
|
import torch.nn as nn |
|
from diffusers.models.attention_processor import Attention |
|
from diffusers.models.resnet import ResnetBlock2D |
|
from diffusers.models.upsampling import Upsample2D |
|
from diffusers.models.downsampling import Downsample2D |
|
|
|
|
|
class TemporalConvBlock(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    The block is residual: four (GroupNorm -> SiLU -> [Dropout] -> Conv3d) stages
    are applied and their result is added to the (possibly frame-resampled) input.
    The last convolution is zero-initialised, so a freshly built block returns
    exactly its identity branch.

    Args:
        in_dim: Channel count of the input and of the returned tensor. Must be
            divisible by 32 because of ``GroupNorm(32, ...)``.
        out_dim: Channel count used between the stages; defaults to ``in_dim``.
            Must also be divisible by 32.
        dropout: Dropout probability used in the second and third stage.
        up_sample: Double the number of frames. ``conv1`` emits ``2 * out_dim``
            channels that are unfolded into twice as many frames; the identity
            branch repeats every frame twice.
        down_sample: Halve the number of frames with a temporal kernel of 2 and
            stride 2; the identity branch keeps every second frame. The two
            branches only agree in length for an even number of input frames.
            Takes precedence over ``up_sample`` when building ``conv1``.
        spa_stride: Despite the name, this is the spatial *kernel size* of every
            convolution (the spatial stride is always 1). Spatial padding is
            ``int((spa_stride - 1) * 0.5)``.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        spa_pad = int((spa_stride - 1) * 0.5)
        # Temporal padding is never done by the convolutions themselves; the
        # forward pass replicates the first/last frame instead.
        temp_pad = 0
        self.temp_pad = temp_pad

        if down_sample:
            first_conv = nn.Conv3d(
                in_dim,
                out_dim,
                (2, spa_stride, spa_stride),
                stride=(2, 1, 1),
                padding=(0, spa_pad, spa_pad),
            )
        elif up_sample:
            first_conv = nn.Conv3d(
                in_dim,
                out_dim * 2,
                (1, spa_stride, spa_stride),
                padding=(0, spa_pad, spa_pad),
            )
        else:
            first_conv = nn.Conv3d(
                in_dim,
                out_dim,
                (3, spa_stride, spa_stride),
                padding=(temp_pad, spa_pad, spa_pad),
            )
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            first_conv,
        )

        # conv2/conv3 keep the hidden width at ``out_dim`` so that the following
        # GroupNorm(32, out_dim) always sees the channel count it was built for.
        # (Previously they projected to ``in_dim``, which only worked when
        # in_dim == out_dim; for that case the shapes here are unchanged.)
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )
        # Only the last stage projects back to ``in_dim`` so the result can be
        # added to the identity branch.
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
        )

        # Zero-init the final convolution: the block starts out as identity.
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

        self.down_sample = down_sample
        self.up_sample = up_sample

    @staticmethod
    def _pad_frames(hidden_states):
        """Replicate the first and last frame once along dim 2 (the frame axis)."""
        return torch.cat(
            (hidden_states[:, :, :1], hidden_states, hidden_states[:, :, -1:]),
            dim=2,
        )

    def forward(self, hidden_states):
        """
        Apply the residual temporal convolution stack.

        Args:
            hidden_states: 5D tensor laid out as (batch, channels, frames,
                height, width) with ``in_dim`` channels.

        Returns:
            Tensor with ``in_dim`` channels; the frame count is unchanged,
            halved (``down_sample``) or doubled (``up_sample``).
        """
        # Identity branch, resampled along the frame axis to match the conv branch.
        if self.down_sample:
            identity = hidden_states[:, :, ::2]
        elif self.up_sample:
            # Every frame repeated twice: f0, f0, f1, f1, ...
            identity = torch.repeat_interleave(hidden_states, 2, dim=2)
        else:
            identity = hidden_states

        if self.down_sample or self.up_sample:
            # Temporal kernel of 2 (stride 2) or 1: no edge padding required.
            hidden_states = self.conv1(hidden_states)
        else:
            hidden_states = self.conv1(self._pad_frames(hidden_states))

        if self.up_sample:
            # Unfold the doubled channel dimension into twice as many frames.
            hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)

        # Remaining stages all use a temporal kernel of 3 with replicated edges,
        # which keeps the frame count constant.
        for stage in (self.conv2, self.conv3, self.conv4):
            hidden_states = stage(self._pad_frames(hidden_states))

        return identity + hidden_states
|
|
|
|
|
class DownEncoderBlock3D(nn.Module):
    """
    Encoder stage for video input: per-frame 2D resnets interleaved with
    temporal convolution blocks, optionally followed by a temporal and a
    spatial down-sampling step.

    Args:
        in_channels: Channels of the incoming tensor (used by the first resnet).
        out_channels: Channels produced by every layer of the stage.
        dropout: Dropout passed to the ``ResnetBlock2D`` layers. The temporal
            blocks are always built with a dropout of 0.1.
        num_layers: Number of (resnet, temporal conv) pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups,
        resnet_pre_norm, output_scale_factor: Forwarded to ``ResnetBlock2D``.
        add_downsample: Append a ``Downsample2D`` applied frame by frame.
        add_temp_downsample: Append a ``TemporalConvBlock`` built with
            ``down_sample=True`` and ``spa_stride=3``.
        downsample_padding: Padding forwarded to ``Downsample2D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        add_temp_downsample=False,
        downsample_padding=1,
    ):
        super().__init__()

        # Settings shared by every spatial resnet of this stage.
        shared_resnet_kwargs = dict(
            out_channels=out_channels,
            temb_channels=None,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )

        spatial_layers = []
        temporal_layers = []
        for layer_idx in range(num_layers):
            # Only the first layer sees the stage input width.
            layer_in = out_channels if layer_idx > 0 else in_channels
            spatial_layers.append(ResnetBlock2D(in_channels=layer_in, **shared_resnet_kwargs))
            temporal_layers.append(TemporalConvBlock(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_layers)
        self.temp_convs = nn.ModuleList(temporal_layers)

        # The temporal down-sampler attribute only exists when requested.
        if add_temp_downsample:
            self.temp_convs_down = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                down_sample=True,
                spa_stride=3,
            )
        self.add_temp_downsample = add_temp_downsample

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
            if add_downsample
            else None
        )

    def _set_partial_grad(self):
        """Enable gradients for the temporal conv blocks and the spatial down-samplers."""
        # NOTE(review): `temp_convs_down` is not touched here — confirm that
        # leaving its requires_grad state as-is is intentional.
        for temporal_block in self.temp_convs:
            temporal_block.requires_grad_(True)
        for spatial_down in self.downsamplers or ():
            spatial_down.requires_grad_(True)

    def forward(self, hidden_states):
        """
        Run the stage on a ``(batch, channels, frames, height, width)`` tensor.

        Returns a tensor in the same 5D layout.
        """
        batch = hidden_states.shape[0]

        for spatial_block, temporal_block in zip(self.resnets, self.temp_convs):
            # 2D resnet works on individual frames: fold frames into the batch.
            frames_2d = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            frames_2d = spatial_block(frames_2d, temb=None)
            # Back to video layout for the temporal convolution.
            hidden_states = temporal_block(rearrange(frames_2d, '(b n) c h w -> b c n h w', b=batch))

        if self.add_temp_downsample:
            hidden_states = self.temp_convs_down(hidden_states)

        if self.downsamplers is None:
            return hidden_states

        # Spatial down-sampling is again applied frame by frame.
        frames_2d = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        for spatial_down in self.downsamplers:
            frames_2d = spatial_down(frames_2d)
        return rearrange(frames_2d, '(b n) c h w -> b c n h w', b=batch)
|
|
|
|
|
class UpDecoderBlock3D(nn.Module):
    """
    Decoder stage for video input: per-frame 2D resnets interleaved with
    temporal convolution blocks, optionally followed by a temporal and a
    spatial up-sampling step.

    Args:
        in_channels: Channels of the incoming tensor (used by the first resnet).
        out_channels: Channels produced by every layer of the stage.
        dropout: Dropout passed to the ``ResnetBlock2D`` layers. The temporal
            blocks are always built with a dropout of 0.1.
        num_layers: Number of (resnet, temporal conv) pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups,
        resnet_pre_norm, output_scale_factor: Forwarded to ``ResnetBlock2D``.
        add_upsample: Append an ``Upsample2D`` applied frame by frame.
        add_temp_upsample: Append a ``TemporalConvBlock`` built with
            ``up_sample=True`` and ``spa_stride=3``.
        temb_channels: Forwarded to ``ResnetBlock2D``; note that ``forward``
            always calls the resnets with ``temb=None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        add_temp_upsample=False,
        temb_channels=None,
    ):
        super().__init__()
        self.add_upsample = add_upsample

        # Settings shared by every spatial resnet of this stage.
        shared_resnet_kwargs = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )

        spatial_layers = []
        temporal_layers = []
        for layer_idx in range(num_layers):
            # Only the first layer sees the stage input width.
            layer_in = out_channels if layer_idx > 0 else in_channels
            spatial_layers.append(ResnetBlock2D(in_channels=layer_in, **shared_resnet_kwargs))
            temporal_layers.append(TemporalConvBlock(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_layers)
        self.temp_convs = nn.ModuleList(temporal_layers)

        # The temporal up-sampler attribute only exists when requested.
        self.add_temp_upsample = add_temp_upsample
        if add_temp_upsample:
            self.temp_conv_up = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                up_sample=True,
                spa_stride=3,
            )

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if self.add_upsample
            else None
        )

    def _set_partial_grad(self):
        """Enable gradients for the temporal conv blocks and the spatial up-samplers."""
        # NOTE(review): `temp_conv_up` is not touched here — confirm that
        # leaving its requires_grad state as-is is intentional.
        for temporal_block in self.temp_convs:
            temporal_block.requires_grad_(True)
        if self.add_upsample:
            self.upsamplers.requires_grad_(True)

    def forward(self, hidden_states):
        """
        Run the stage on a ``(batch, channels, frames, height, width)`` tensor.

        Returns a tensor in the same 5D layout.
        """
        batch = hidden_states.shape[0]

        for spatial_block, temporal_block in zip(self.resnets, self.temp_convs):
            # 2D resnet works on individual frames: fold frames into the batch.
            frames_2d = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            frames_2d = spatial_block(frames_2d, temb=None)
            # Back to video layout for the temporal convolution.
            hidden_states = temporal_block(rearrange(frames_2d, '(b n) c h w -> b c n h w', b=batch))

        if self.add_temp_upsample:
            hidden_states = self.temp_conv_up(hidden_states)

        if self.upsamplers is None:
            return hidden_states

        # Spatial up-sampling is again applied frame by frame.
        frames_2d = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        for spatial_up in self.upsamplers:
            frames_2d = spatial_up(frames_2d)
        return rearrange(frames_2d, '(b n) c h w -> b c n h w', b=batch)
|
|
|
|
|
class UNetMidBlock3DConv(nn.Module):
    """
    Middle block for video input.

    Layout: one (resnet, temporal conv) pair, followed by ``num_layers`` groups
    of (optional attention, resnet, temporal conv). Resnets and attention run
    on individual frames; the temporal conv blocks run across frames.

    Args:
        in_channels: Channel count kept throughout the block.
        temb_channels: Forwarded to ``ResnetBlock2D`` (and used as
            ``spatial_norm_dim`` of the attention when
            ``resnet_time_scale_shift == "spatial"``). ``forward`` always calls
            the resnets with ``temb=None``.
        dropout: Dropout passed to the ``ResnetBlock2D`` layers. The temporal
            blocks are always built with a dropout of 0.1.
        num_layers: Number of (attention, resnet, temporal conv) groups after
            the leading pair.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_pre_norm,
        output_scale_factor: Forwarded to ``ResnetBlock2D`` / ``Attention``.
        resnet_groups: GroupNorm groups; ``None`` falls back to
            ``min(in_channels // 4, 32)``.
        add_attention: When ``False`` a ``None`` placeholder is stored instead
            of each attention layer and skipped in ``forward``.
        attention_head_dim: Per-head width; ``None`` means a single head of
            ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # Every resnet of the block is built with the same configuration.
        resnet_kwargs = dict(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )

        # There is always at least one resnet / temporal conv pair.
        resnets = [ResnetBlock2D(**resnet_kwargs)]
        temp_convs = [TemporalConvBlock(in_channels, in_channels, dropout=0.1)]
        attentions = []

        if attention_head_dim is None:
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                # Placeholder keeps `attentions` aligned with resnets[1:].
                attentions.append(None)

            resnets.append(ResnetBlock2D(**resnet_kwargs))
            temp_convs.append(TemporalConvBlock(in_channels, in_channels, dropout=0.1))

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)

    def _set_partial_grad(self):
        """Enable gradients for the temporal conv blocks only."""
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)

    def forward(
        self,
        hidden_states,
    ):
        """
        Run the block on a ``(batch, channels, frames, height, width)`` tensor.

        Returns a tensor in the same 5D layout, for any ``num_layers``.
        """
        bz = hidden_states.shape[0]

        # Leading pair: per-frame resnet, then temporal conv on the video layout.
        hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        hidden_states = self.resnets[0](hidden_states, temb=None)
        hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        hidden_states = self.temp_convs[0](hidden_states)

        for attn, resnet, temp_conv in zip(
            self.attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            # Fold frames into the batch at the start of *every* group: the
            # previous temporal conv always leaves the tensor in 5D layout.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            # `attn` is a None placeholder when the block was built without attention.
            if attn is not None:
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        return hidden_states
|
|