Spaces:
Running
Running
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from diffusers import DDPMScheduler | |
from models.svc.base import SVCTrainer | |
from modules.encoder.condition_encoder import ConditionEncoder | |
from .diffusion_wrapper import DiffusionWrapper | |
class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models.

    It inherits from SVCTrainer and implements the ``_build_model`` and
    ``_forward_step`` methods. In addition to the parent's setup it creates a
    ``DDPMScheduler`` from ``cfg.model.diffusion.scheduler_settings``, which
    ``_forward_step`` uses to add noise to the target mel.
    """

    def __init__(self, args=None, cfg=None):
        r"""Initialise the parent trainer, then the diffusion noise scheduler.

        Args:
            args: Forwarded unchanged to the parent initialiser.
            cfg: Forwarded unchanged to the parent initialiser. After the
                parent returns, ``self.cfg.model.diffusion.scheduler_settings``
                is unpacked as keyword arguments for ``DDPMScheduler`` and its
                ``num_train_timesteps`` entry is read.
        """
        super().__init__(args, cfg)

        # Only for SVC tasks using diffusion
        scheduler_settings = self.cfg.model.diffusion.scheduler_settings
        self.noise_scheduler = DDPMScheduler(**scheduler_settings)
        # Exclusive upper bound for the timesteps sampled in ``_forward_step``.
        self.diffusion_timesteps = scheduler_settings.num_train_timesteps

    ### Following are methods only for diffusion models ###
    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function.

        Copies ``f0_min`` / ``f0_max`` from the preprocessing config into the
        condition encoder config, builds the condition encoder and the
        diffusion wrapper (stored as ``self.condition_encoder`` and
        ``self.acoustic_mapper``), and logs their parameter counts.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.
        """
        # TODO: sort out the config
        encoder_cfg = self.cfg.model.condition_encoder
        encoder_cfg.f0_min = self.cfg.preprocess.f0_min
        encoder_cfg.f0_max = self.cfg.preprocess.f0_max

        self.condition_encoder = ConditionEncoder(encoder_cfg)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        r"""Count the parameter elements of a module or of a dict of modules.

        Args:
            model: Either an object exposing ``parameters()`` (such as a
                ``torch.nn.Module``) or a ``dict`` whose values are such
                objects (or further dicts of them).

        Returns:
            int: Sum of ``numel()`` over all parameters; ``0`` for an empty
            dict.
        """
        # Integer arithmetic on both paths so the return type does not depend
        # on whether a dict or a single module was passed.
        if isinstance(model, dict):
            return sum(self.count_parameters(module) for module in model.values())
        return sum(p.numel() for p in model.parameters())

    def _check_nan(self, batch, loss, y_pred, y_gt):
        r"""Log the batch when the loss contains NaN, then run the parent's check.

        Args:
            batch: Mapping whose keys and values are logged (at info level)
                when any element of ``loss`` is NaN.
            loss: Loss tensor to inspect.
            y_pred: Forwarded to the parent's ``_check_nan``.
            y_gt: Forwarded to the parent's ``_check_nan``.
        """
        if torch.isnan(loss).any():
            for key, value in batch.items():
                self.logger.info(key)
                self.logger.info(value)
        # The parent check is always invoked, whether or not NaN was seen here.
        super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Samples Gaussian noise and one random timestep per batch item, adds
        the noise to ``batch["mel"]`` through the noise scheduler, runs the
        diffusion wrapper on the noisy mel conditioned on the condition
        encoder's output, and computes the loss between the wrapper's output
        and the sampled noise.

        Args:
            batch: Mapping that provides ``"mel"`` and ``"mask"`` (after the
                optional on-the-fly feature extraction). The whole mapping is
                also passed to the condition encoder.

        Returns:
            The loss returned by ``self._compute_loss``.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # On-the-fly features extraction
            batch = self._extract_svc_features(batch)

        mel_input = batch["mel"]
        # Noise is the regression target; it is drawn before the timesteps so
        # the order of random draws stays fixed.
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        # One timestep per batch item, drawn from [0, diffusion_timesteps).
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
        conditioner = self.condition_encoder(batch)

        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss