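# NOTE: the following lines are residue from the hosting site's file viewer,
# kept as comments so the module remains valid Python:
#   fffiloni's picture
#   Upload 15 files
#   cdcfdd8 verified
#   raw
#   history blame
#   15.4 kB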
from typing import Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.models.downsampling import Downsample2D
class TemporalConvBlock(nn.Module):
    """
    Residual temporal convolution block for video (sequence of images) features.

    Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    The block runs four ``GroupNorm -> SiLU -> Conv3d`` stages (``conv1`` ..
    ``conv4``) over a ``(batch, channels, frames, height, width)`` tensor and
    adds the result to a skip connection. The last convolution is
    zero-initialised, so a freshly constructed block returns exactly its skip
    connection.

    Three modes are selected at construction time:

    * default: the frame count is preserved; ``conv1`` uses a temporal kernel
      of 3 on an input whose first and last frames were replicated once.
    * ``down_sample=True``: ``conv1`` uses a temporal kernel of 2 with stride
      2, and the skip connection keeps every second frame.
      NOTE(review): for an odd frame count the skip connection and the conv
      branch end up with different lengths, so an even frame count is
      presumably expected - confirm with callers.
    * ``up_sample=True`` (only consulted when ``down_sample`` is false):
      ``conv1`` emits ``2 * out_dim`` channels which are folded into the time
      axis, doubling the frame count; the skip connection repeats every frame
      twice.

    Args:
        in_dim: number of input channels; also the number of output channels.
        out_dim: number of hidden channels (defaults to ``in_dim``).
        dropout: dropout probability used inside ``conv2`` and ``conv3``.
        up_sample: double the number of frames.
        down_sample: halve the number of frames (takes precedence over
            ``up_sample``).
        spa_stride: despite its name, this is the spatial kernel size of every
            ``Conv3d``; the spatial padding ``int((spa_stride - 1) * 0.5)``
            keeps height and width unchanged for odd values.

    Both ``in_dim`` and ``out_dim`` must be divisible by 32 because every
    stage uses ``GroupNorm(32, ...)``.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        spa_pad = int((spa_stride - 1) * 0.5)
        # Temporal padding is done by hand in ``forward`` (edge replication),
        # so the convolutions themselves never pad along the time axis.
        temp_pad = 0
        self.temp_pad = temp_pad

        spatial_kernel = (spa_stride, spa_stride)
        same_padding = (temp_pad, spa_pad, spa_pad)

        if down_sample:
            first_conv = nn.Conv3d(
                in_dim, out_dim, (2, *spatial_kernel), stride=(2, 1, 1), padding=(0, spa_pad, spa_pad)
            )
        elif up_sample:
            # Twice the channels: the factor 2 is moved into the time axis in ``forward``.
            first_conv = nn.Conv3d(in_dim, out_dim * 2, (1, *spatial_kernel), padding=(0, spa_pad, spa_pad))
        else:
            first_conv = nn.Conv3d(in_dim, out_dim, (3, *spatial_kernel), padding=same_padding)
        self.conv1 = nn.Sequential(nn.GroupNorm(32, in_dim), nn.SiLU(), first_conv)

        # The hidden stages stay at ``out_dim`` channels. The original code
        # mapped back to ``in_dim`` here, which made the next stage's
        # GroupNorm/Conv3d (both declared for ``out_dim``) fail whenever
        # ``in_dim != out_dim``. Shapes are unchanged when the two are equal.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, *spatial_kernel), padding=same_padding),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, *spatial_kernel), padding=same_padding),
        )
        # Only the last stage projects back to ``in_dim`` so that the result
        # can be added to the skip connection.
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Conv3d(out_dim, in_dim, (3, *spatial_kernel), padding=same_padding),
        )

        # Zero out the last layer params, so the conv block starts as identity.
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

        self.down_sample = down_sample
        self.up_sample = up_sample

    @staticmethod
    def _pad_temporal(hidden_states):
        """Replicate the first and the last frame once along the time axis (dim 2)."""
        return torch.cat(
            (hidden_states[:, :, 0:1], hidden_states, hidden_states[:, :, -1:]),
            dim=2,
        )

    def forward(self, hidden_states):
        """
        Apply the block.

        Args:
            hidden_states: tensor indexed as (batch, channels, frames, height, width).

        Returns:
            The skip connection plus the output of the four conv stages. The
            frame count is unchanged, halved or doubled depending on the mode.
        """
        # Build the skip connection with the frame count of the output.
        if self.down_sample:
            identity = hidden_states[:, :, ::2]
        elif self.up_sample:
            # Every frame appears twice in a row: f0, f0, f1, f1, ...
            identity = torch.repeat_interleave(hidden_states, 2, dim=2)
        else:
            identity = hidden_states

        if self.down_sample or self.up_sample:
            # Temporal kernels of size 2 (stride 2) or 1 need no edge padding.
            hidden_states = self.conv1(hidden_states)
        else:
            hidden_states = self.conv1(self._pad_temporal(hidden_states))

        if self.up_sample:
            # Fold the doubled channel dimension into the time axis.
            hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)

        # Each remaining stage uses a temporal kernel of 3 without padding, so
        # the edges are replicated first to keep the frame count.
        hidden_states = self.conv2(self._pad_temporal(hidden_states))
        hidden_states = self.conv3(self._pad_temporal(hidden_states))
        hidden_states = self.conv4(self._pad_temporal(hidden_states))

        return identity + hidden_states
class DownEncoderBlock3D(nn.Module):
    """
    Encoder stage that alternates per-frame 2D resnets with temporal conv blocks.

    Each of the ``num_layers`` layers is a ``ResnetBlock2D`` (applied to every
    frame independently, without a time embedding) followed by a
    ``TemporalConvBlock``. Optionally the stage ends with a temporal
    down-sampling block and/or a spatial ``Downsample2D``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by every layer of the stage.
        dropout: dropout of the 2D resnets (the temporal blocks always use 0.1).
        num_layers: number of resnet / temporal-conv pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups,
        resnet_pre_norm, output_scale_factor: forwarded to ``ResnetBlock2D``.
        add_downsample: append a spatial ``Downsample2D``.
        add_temp_downsample: append a temporal down-sampling ``TemporalConvBlock``.
        downsample_padding: padding forwarded to ``Downsample2D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        add_temp_downsample=False,
        downsample_padding=1,
    ):
        super().__init__()

        spatial_blocks = []
        temporal_blocks = []
        for layer_idx in range(num_layers):
            # Only the first layer sees the stage's input width.
            layer_in = out_channels if layer_idx > 0 else in_channels
            spatial_blocks.append(
                ResnetBlock2D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temporal_blocks.append(TemporalConvBlock(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_blocks)
        self.temp_convs = nn.ModuleList(temporal_blocks)

        if add_temp_downsample:
            self.temp_convs_down = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                down_sample=True,
                spa_stride=3,
            )
        self.add_temp_downsample = add_temp_downsample

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
            if add_downsample
            else None
        )

    def _set_partial_grad(self):
        """Enable gradients for the temporal conv blocks and the spatial down-samplers."""
        for temporal in self.temp_convs:
            temporal.requires_grad_(True)
        # ``None`` and an empty list are both skipped.
        for downsampler in self.downsamplers or ():
            downsampler.requires_grad_(True)

    def forward(self, hidden_states):
        """
        Run the stage on a tensor indexed as (batch, channels, frames, height, width).

        The 2D modules operate on frames folded into the batch axis; the
        temporal blocks operate on the unfolded 5-D tensor.
        """
        batch = hidden_states.shape[0]

        for spatial, temporal in zip(self.resnets, self.temp_convs):
            frames = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            frames = spatial(frames, temb=None)
            hidden_states = temporal(rearrange(frames, '(b n) c h w -> b c n h w', b=batch))

        if self.add_temp_downsample:
            hidden_states = self.temp_convs_down(hidden_states)

        if self.downsamplers is None:
            return hidden_states

        frames = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        for downsampler in self.downsamplers:
            frames = downsampler(frames)
        return rearrange(frames, '(b n) c h w -> b c n h w', b=batch)
class UpDecoderBlock3D(nn.Module):
    """
    Decoder stage that alternates per-frame 2D resnets with temporal conv blocks.

    Each of the ``num_layers`` layers is a ``ResnetBlock2D`` (applied to every
    frame independently) followed by a ``TemporalConvBlock``. Optionally the
    stage ends with a temporal up-sampling block and/or a spatial
    ``Upsample2D``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by every layer of the stage.
        dropout: dropout of the 2D resnets (the temporal blocks always use 0.1).
        num_layers: number of resnet / temporal-conv pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups,
        resnet_pre_norm, output_scale_factor, temb_channels: forwarded to
            ``ResnetBlock2D``.
        add_upsample: append a spatial ``Upsample2D``.
        add_temp_upsample: append a temporal up-sampling ``TemporalConvBlock``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        add_temp_upsample=False,
        temb_channels=None,
    ):
        super().__init__()
        self.add_upsample = add_upsample

        spatial_blocks = []
        temporal_blocks = []
        for layer_idx in range(num_layers):
            # Only the first layer sees the stage's input width.
            layer_in = out_channels if layer_idx > 0 else in_channels
            spatial_blocks.append(
                ResnetBlock2D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temporal_blocks.append(TemporalConvBlock(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_blocks)
        self.temp_convs = nn.ModuleList(temporal_blocks)

        self.add_temp_upsample = add_temp_upsample
        if add_temp_upsample:
            self.temp_conv_up = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                up_sample=True,
                spa_stride=3,
            )

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if self.add_upsample
            else None
        )

    def _set_partial_grad(self):
        """Enable gradients for the temporal conv blocks and the spatial up-samplers."""
        for temporal in self.temp_convs:
            temporal.requires_grad_(True)
        if not self.add_upsample:
            return
        self.upsamplers.requires_grad_(True)

    def forward(self, hidden_states):
        """
        Run the stage on a tensor indexed as (batch, channels, frames, height, width).

        The 2D modules operate on frames folded into the batch axis; the
        temporal blocks operate on the unfolded 5-D tensor. The resnets are
        always called with ``temb=None``.
        """
        batch = hidden_states.shape[0]

        for spatial, temporal in zip(self.resnets, self.temp_convs):
            frames = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            frames = spatial(frames, temb=None)
            hidden_states = temporal(rearrange(frames, '(b n) c h w -> b c n h w', b=batch))

        if self.add_temp_upsample:
            hidden_states = self.temp_conv_up(hidden_states)

        if self.upsamplers is None:
            return hidden_states

        frames = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        for upsampler in self.upsamplers:
            frames = upsampler(frames)
        return rearrange(frames, '(b n) c h w -> b c n h w', b=batch)
class UNetMidBlock3DConv(nn.Module):
    """
    Middle block: per-frame 2D resnets and attention interleaved with temporal convs.

    The block always starts with one ``ResnetBlock2D`` + ``TemporalConvBlock``
    pair. Each of the ``num_layers`` following layers is
    ``Attention`` (optional) -> ``ResnetBlock2D`` -> ``TemporalConvBlock``.
    The 2D modules see frames folded into the batch axis, the temporal blocks
    see the 5-D tensor.

    Args:
        in_channels: channel count of the input; every sub-module keeps it.
        temb_channels: forwarded to ``ResnetBlock2D`` and, when
            ``resnet_time_scale_shift == "spatial"``, to ``Attention`` as
            ``spatial_norm_dim``.
            NOTE(review): ``forward`` always calls the resnets with
            ``temb=None`` - confirm that callers pass ``temb_channels=None``
            if the installed diffusers version requires a matching ``temb``.
        dropout: dropout of the 2D resnets (the temporal blocks always use 0.1).
        num_layers: number of attention / resnet / temporal-conv layers after
            the initial pair.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_pre_norm,
        output_scale_factor: forwarded to ``ResnetBlock2D`` (some also to
            ``Attention``).
        resnet_groups: group count for the norms; ``None`` selects
            ``min(in_channels // 4, 32)``.
        add_attention: when false, no attention module is created and the
            attention step is skipped in ``forward``.
        attention_head_dim: size of one attention head; ``None`` means a
            single head spanning ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvBlock(
                in_channels,
                in_channels,
                dropout=0.1,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                # Placeholder keeps ``attentions`` aligned with ``resnets[1:]``;
                # ``forward`` skips it.
                attentions.append(None)

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)

    def _set_partial_grad(self):
        """Enable gradients for every temporal conv block of the mid block."""
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)

    def forward(
        self,
        hidden_states,
    ):
        """
        Run the mid block on a tensor indexed as (batch, channels, frames, height, width).

        Returns:
            A 5-D tensor with the same index layout as the input.
        """
        bz = hidden_states.shape[0]

        # Initial resnet + temporal conv pair.
        hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        hidden_states = self.resnets[0](hidden_states, temb=None)
        hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        hidden_states = self.temp_convs[0](hidden_states)

        for attn, resnet, temp_conv in zip(
            self.attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            # Fold frames into the batch at the start of *every* layer: the
            # previous temporal conv left the tensor in 5-D form. Folding only
            # once before the loop broke ``num_layers > 1`` and returned a 4-D
            # tensor for ``num_layers == 0``.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            # ``add_attention=False`` stores ``None`` placeholders.
            if attn is not None:
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        return hidden_states