Spaces:
Running
Running
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import torch.nn as nn | |
from modules.general.utils import Conv1d, zero_module | |
from .residual_block import ResidualBlock | |
class BiDilConv(nn.Module):
    r"""Dilated CNN architecture with residual connections, default diffusion decoder.

    The network is a kernel-size-1 input projection followed by
    ``n_res_block`` residual blocks whose dilation cycles through
    ``1, 2, ..., 2 ** (dilation_cycle_length - 1)``. The skip outputs of all
    blocks are summed, rescaled by ``1 / sqrt(n_res_block)`` and passed
    through an output projection.

    Args:
        input_channel: The number of input channels.
        base_channel: The number of base channels.
        n_res_block: The number of residual blocks. Must be at least 1.
        conv_kernel_size: The kernel size of convolutional layers.
        dilation_cycle_length: The cycle length of dilation. Must be at least 1.
        conditioner_size: The size of conditioner.
        output_channel: The number of output channels. Any non-positive value
            (the default is ``-1``) falls back to ``input_channel``.

    Raises:
        ValueError: If ``n_res_block`` or ``dilation_cycle_length`` is
            smaller than 1.
    """

    def __init__(
        self,
        input_channel: int,
        base_channel: int,
        n_res_block: int,
        conv_kernel_size: int,
        dilation_cycle_length: int,
        conditioner_size: int,
        output_channel: int = -1,
    ):
        super().__init__()

        # Fail fast: with no residual block ``forward`` would divide ``None``
        # by zero, and a zero cycle length would raise a bare
        # ZeroDivisionError from the modulo in the dilation schedule below.
        if n_res_block < 1:
            raise ValueError(f"n_res_block must be >= 1, got {n_res_block}")
        if dilation_cycle_length < 1:
            raise ValueError(
                f"dilation_cycle_length must be >= 1, got {dilation_cycle_length}"
            )

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        # A non-positive ``output_channel`` means "same as the input".
        self.output_channel = output_channel if output_channel > 0 else input_channel

        # Kernel-size-1 projection from the input channels to the base width.
        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

        # The dilation doubles with every block and restarts at 1 after
        # ``dilation_cycle_length`` blocks.
        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

        # NOTE(review): ``zero_module`` presumably zero-initialises the final
        # convolution so the decoder starts out emitting zeros -- confirm
        # against modules.general.utils.
        self.out_proj = nn.Sequential(
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
        """Run the decoder and return the projected sum of skip connections.

        Args:
            x: Noisy mel-spectrogram [B x ``n_mel`` x L]
            y: FILM embeddings with the shape of (B, ``base_channel``)
            context: Context with the shape of [B x ``d_context`` x L], default to None.

        Returns:
            The output of ``out_proj`` applied to the rescaled skip sum;
            presumably [B x ``output_channel`` x L] -- TODO confirm.
        """
        h = self.input(x)

        # Each block returns the updated hidden state and a skip tensor;
        # only the skip tensors feed the output.
        skip = None
        for block in self.residual_blocks:
            h, skip_connection = block(h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        # Rescale so the magnitude of the sum does not grow with the number
        # of blocks (presumably the intent of the sqrt normalisation).
        out = skip / math.sqrt(self.n_res_block)
        return self.out_proj(out)