pyramid-flow / video_vae /modeling_resnet.py
multimodalart's picture
Upload 33 files
f0533a5 verified
raw
history blame
No virus
27.4 kB
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.normalization import AdaGroupNorm
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
class CausalResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
pre_norm: bool = True,
eps: float = 1e-6,
non_linearity: str = "swish",
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
linear_cls = nn.Linear
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)
self.nonlinearity = get_activation(non_linearity)
self.upsample = self.downsample = None
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = CausalConv3d(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor = None,
is_init_image=True,
temporal_chunk=False,
) -> torch.FloatTensor:
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class ResnetBlock2D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
pre_norm: bool = True,
eps: float = 1e-6,
non_linearity: str = "swish",
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
linear_cls = nn.Linear
conv_cls = nn.Conv3d
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
self.upsample = self.downsample = None
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor = None,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class CausalDownsample2x(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = True,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size=3,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
stride = (1, 2, 2)
self.name = name
if use_conv:
conv = CausalConv3d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool3d(kernel_size=stride, stride=stride)
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return hidden_states
class Downsample2D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = True,
out_channels: Optional[int] = None,
padding: int = 0,
name: str = "conv",
kernel_size=3,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = (1, 2, 2)
self.name = name
conv_cls = nn.Conv3d
if use_conv:
conv = conv_cls(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1, 1, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class TemporalDownsample2x(nn.Module):
"""A Temporal downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 0,
kernel_size=3,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = (2, 1, 1)
conv_cls = nn.Conv3d
if use_conv:
conv = conv_cls(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:
raise NotImplementedError("Not implemented for temporal downsample without")
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
if hidden_states.shape[2] == 1:
# image
pad = (1, 1, 1, 1, 1, 1)
else:
# video
pad = (1, 1, 1, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
hidden_states = self.conv(hidden_states)
return hidden_states
class CausalTemporalDownsample2x(nn.Module):
"""A Temporal downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
kernel_size=3,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
stride = (2, 1, 1)
conv_cls = nn.Conv3d
if use_conv:
conv = CausalConv3d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
)
else:
raise NotImplementedError("Not implemented for temporal downsample without")
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
return hidden_states
class Upsample2D(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
bias=True,
interpolate=False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv3d
conv = None
if interpolate:
raise NotImplementedError("Not implemented for spatial upsample with interpolate")
else:
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels * 4, kernel_size=kernel_size, padding=padding, bias=bias)
self.conv = conv
self.conv.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(
self,
hidden_states: torch.FloatTensor,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
return hidden_states
class CausalUpsample2x(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = 3,
bias=True,
interpolate=False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.name = name
self.interpolate = interpolate
conv = None
if interpolate:
raise NotImplementedError("Not implemented for spatial upsample with interpolate")
else:
conv = CausalConv3d(self.channels, self.out_channels * 4, kernel_size=kernel_size, stride=1, bias=bias)
self.conv = conv
def forward(
self,
hidden_states: torch.FloatTensor,
is_init_image=True, temporal_chunk=False,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
return hidden_states
class TemporalUpsample2x(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = True,
out_channels: Optional[int] = None,
kernel_size: Optional[int] = None,
padding=1,
bias=True,
interpolate=False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.interpolate = interpolate
conv_cls = nn.Conv3d
conv = None
if interpolate:
raise NotImplementedError("Not implemented for spatial upsample with interpolate")
else:
# depth to space operator
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels * 2, kernel_size=kernel_size, padding=padding, bias=bias)
self.conv = conv
def forward(
self,
hidden_states: torch.FloatTensor,
is_image: bool = False,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
t = hidden_states.shape[2]
hidden_states = self.conv(hidden_states)
hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (p t) h w', p=2)
if t == 1 and is_image:
hidden_states = hidden_states[:, :, 1:]
return hidden_states
class CausalTemporalUpsample2x(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = True,
out_channels: Optional[int] = None,
kernel_size: Optional[int] = 3,
bias=True,
interpolate=False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.interpolate = interpolate
conv = None
if interpolate:
raise NotImplementedError("Not implemented for spatial upsample with interpolate")
else:
# depth to space operator
conv = CausalConv3d(self.channels, self.out_channels * 2, kernel_size=kernel_size, stride=1, bias=bias)
self.conv = conv
def forward(
self,
hidden_states: torch.FloatTensor,
is_init_image=True, temporal_chunk=False,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
t = hidden_states.shape[2]
hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)
if is_init_image:
hidden_states = hidden_states[:, :, 1:]
return hidden_states