# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from diffusers import DDPMScheduler

from models.svc.base import SVCTrainer
from modules.encoder.condition_encoder import ConditionEncoder
from .diffusion_wrapper import DiffusionWrapper


class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
    implements ``_build_model`` and ``_forward_step`` methods.
    """

    def __init__(self, args=None, cfg=None):
        super().__init__(args, cfg)

        # Only for SVC tasks using diffusion
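        # ``scheduler_settings`` is unpacked straight into diffusers' DDPMScheduler,
        # so it is expected to carry fields such as ``num_train_timesteps``,
        # ``beta_start``, ``beta_end``, and ``beta_schedule``.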
        self.noise_scheduler = DDPMScheduler(
            **self.cfg.model.diffusion.scheduler_settings,
        )
        self.diffusion_timesteps = (
            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
        )

    ### The following methods are specific to diffusion models ###
    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function."""

        # TODO: sort out the config
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
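        # Group the two sub-modules in a ModuleList so they are optimized and
        # checkpointed together as a single model.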
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        r"""Count the parameters of a single module, or of every module in a dict."""
        if isinstance(model, dict):
            return sum(
                sum(p.numel() for p in value.parameters()) for value in model.values()
            )
        return sum(p.numel() for p in model.parameters())

    def _check_nan(self, batch, loss, y_pred, y_gt):
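        # On a NaN loss, dump every batch entry to the log to help locate the
        # offending feature, then delegate to the parent class's check.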
        if torch.any(torch.isnan(loss)):
            for k, v in batch.items():
                self.logger.info(k)
                self.logger.info(v)

            super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # On-the-fly features extraction
            batch = self._extract_svc_features(batch)

            # For debugging: print every extracted feature, then stop.
            # for k, v in batch.items():
            #     print(k, v.shape, v)
            # exit()

        mel_input = batch["mel"]
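        # Standard DDPM training step: draw Gaussian noise and a random timestep
        # per sample, then corrupt the clean mel via the scheduler's closed-form
        # forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.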
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
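        # Encode the frame-level features in ``batch`` (content, pitch, etc.)
        # into the conditioning signal consumed by the denoiser.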
        conditioner = self.condition_encoder(batch)

        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        # epsilon-prediction objective: regress the injected noise, with padded
        # frames excluded from the loss via ``mask``
        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss
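

# A minimal launch sketch (hypothetical; the real Amphion entry script builds
# ``args`` and ``cfg`` from command-line arguments and a JSON config):
#
#     trainer = DiffusionTrainer(args=args, cfg=cfg)
#     trainer.train_loop()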