# maskgct/models/svc/diffusion/diffusion_wrapper.py
# Hecheng0625's picture
# Upload 409 files
# c968fc3 verified
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from modules.diffusion import BiDilConv
from modules.encoder.position_encoder import PositionEncoder
class DiffusionWrapper(nn.Module):
    """Pair a diffusion-step encoder with the configured diffusion network.

    ``diff_encoder`` (a ``PositionEncoder``) embeds the time step ``t``;
    ``neural_network`` (currently only ``BiDilConv``) is then called with
    the mel spectrogram, the step embedding and the conditioner.
    """

    def __init__(self, cfg):
        """Build the step encoder and the network from ``cfg``.

        Args:
            cfg: Configuration object. ``cfg.model.diffusion`` and
                ``cfg.preprocess.n_mel`` are read here.

        Raises:
            ValueError: If ``cfg.model.diffusion.model_type`` is not
                ``"bidilconv"`` (compared case-insensitively).
        """
        super().__init__()

        self.cfg = cfg
        self.diff_cfg = cfg.model.diffusion

        # The step encoder's output width is the network's base channel count.
        step_cfg = self.diff_cfg.step_encoder
        self.diff_encoder = PositionEncoder(
            d_raw_emb=step_cfg.dim_raw_embedding,
            d_out=self.diff_cfg.bidilconv.base_channel,
            d_mlp=step_cfg.dim_hidden_layer,
            activation_function=step_cfg.activation,
            n_layer=step_cfg.num_layer,
            max_period=step_cfg.max_period,
        )

        # FIXME: Only BiDilConv is supported for now (debugging).
        if self.diff_cfg.model_type.lower() != "bidilconv":
            raise ValueError(
                f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
            )

        self.neural_network = BiDilConv(
            input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
        )

    def forward(self, x, t, c):
        """Run the network on ``x`` for step ``t`` under condition ``c``.

        Args:
            x: [N, T, mel_band] of mel spectrogram
            t: Diffusion time step with shape of [N]
            c: [N, T, conditioner_size] of conditioner

        Returns:
            [N, T, mel_band] of mel spectrogram

        Raises:
            AssertionError: If the leading dimensions of ``x`` and ``c``
                differ, if the batch sizes of ``x`` and ``t`` differ, if
                ``t`` is not 1D, or if the network output does not have
                the shape of ``x``.
        """
        assert x.size()[:-1] == c.size()[:-1], (
            "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
        )
        assert x.size(0) == t.size(0), (
            "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
        )
        assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())

        batch, frames, n_bands = x.size()
        expected_shape = (batch, frames, n_bands)

        # The network is fed channel-first tensors, so swap the last two axes.
        mel_ct = x.transpose(1, 2).contiguous()  # [N, mel_band, T]
        cond_ct = c.transpose(1, 2).contiguous()  # [N, conditioner_size, T]
        step_emb = self.diff_encoder(t).contiguous()  # [N, base_channel]

        out = self.neural_network(mel_ct, step_emb, cond_ct)
        out = out.transpose(1, 2).contiguous()  # [N, T, mel_band]

        assert out.size() == expected_shape, (
            "h mismatch with input x, got \n h: {} \n x: {}".format(
                out.size(), expected_shape
            )
        )
        return out