# Source: maskgct/models/svc/diffusion/diffusion_trainer.py
# Uploaded by Hecheng0625 ("Upload 409 files", commit c968fc3, verified)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from diffusers import DDPMScheduler
from models.svc.base import SVCTrainer
from modules.encoder.condition_encoder import ConditionEncoder
from .diffusion_wrapper import DiffusionWrapper
class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
    implements ``_build_model`` and ``_forward_step`` methods.

    Training objective (see ``_forward_step``): for each example, sample
    Gaussian noise and a random timestep, corrupt the target mel with
    ``DDPMScheduler.add_noise``, and train the acoustic mapper to predict the
    sampled noise from the noisy mel, the timestep and the encoded conditions.
    """

    def __init__(self, args=None, cfg=None):
        r"""Initialize the base SVC trainer, then the DDPM noise scheduler.

        Args:
            args: command-line arguments forwarded to ``SVCTrainer``.
            cfg: experiment config forwarded to ``SVCTrainer``; must provide
                ``cfg.model.diffusion.scheduler_settings`` (keyword arguments
                for ``DDPMScheduler``, including ``num_train_timesteps``).
        """
        SVCTrainer.__init__(self, args, cfg)

        # Only for SVC tasks using diffusion
        self.noise_scheduler = DDPMScheduler(
            **self.cfg.model.diffusion.scheduler_settings,
        )
        # Exclusive upper bound for the timesteps sampled in ``_forward_step``.
        self.diffusion_timesteps = (
            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
        )

    ### Following are methods only for diffusion models ###

    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function.

        Creates the condition encoder and the diffusion acoustic mapper, logs
        their parameter counts (in millions), and returns both wrapped in a
        ``torch.nn.ModuleList`` (encoder first, mapper second). Both modules
        are also stored on ``self`` for use in ``_forward_step``.

        Side effect: copies ``cfg.preprocess.f0_min`` / ``f0_max`` into
        ``cfg.model.condition_encoder`` before the encoder is constructed.
        """
        # TODO: sort out the config
        # The encoder's config needs the F0 range, which lives in the
        # preprocessing section, so it is copied over in place.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        r"""Count the parameter elements of ``model``.

        Args:
            model: an object exposing ``.parameters()`` (e.g. a
                ``torch.nn.Module``), or a ``dict`` whose values expose
                ``.parameters()``. For a dict, the counts of all values are
                summed.

        Returns:
            int: the total ``numel()`` over all parameters (0 for an empty dict).
        """
        modules = model.values() if isinstance(model, dict) else (model,)
        return sum(p.numel() for module in modules for p in module.parameters())

    def _check_nan(self, batch, loss, y_pred, y_gt):
        r"""Dump the batch to the logger if ``loss`` contains NaN, then delegate.

        Args:
            batch: mapping of input names to values; every key and value is
                logged when a NaN is detected, to help locate the bad sample.
            loss: loss tensor checked with ``torch.isnan``.
            y_pred: model prediction, forwarded to the parent check.
            y_gt: regression target, forwarded to the parent check.
        """
        if torch.any(torch.isnan(loss)):
            for k, v in batch.items():
                self.logger.info(k)
                self.logger.info(v)

        # The parent implementation takes no ``batch`` argument.
        super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Args:
            batch: mapping that must contain ``"mel"`` (the target features;
                dimension 0 is treated as the batch size) and ``"mask"``
                (passed to the loss), plus whatever inputs the condition
                encoder reads. When ``self.online_features_extraction`` is
                set, features are extracted on the fly first.

        Returns:
            The loss between the predicted and the sampled noise.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # On-the-fly features extraction
            batch = self._extract_svc_features(batch)

        mel_input = batch["mel"]
        # The model is trained to predict this noise (it is the loss target).
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        # One timestep per example, uniform over [0, diffusion_timesteps).
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
        conditioner = self.condition_encoder(batch)

        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss