File size: 1,968 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from models.svc.base import SVCTrainer
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer
from utils.ssim import SSIM


class TransformerTrainer(SVCTrainer):
    """SVC trainer built from a condition encoder plus a Transformer/Conformer
    acoustic mapper, optimised with a masked L1 loss plus an SSIM loss.
    """

    def __init__(self, args, cfg):
        SVCTrainer.__init__(self, args, cfg)
        # Criterion for the second loss term computed in _forward_step.
        self.ssim_loss = SSIM()

    def _build_model(self):
        """Construct the condition encoder and the acoustic mapper.

        Copies ``f0_min``/``f0_max`` from ``cfg.preprocess`` into
        ``cfg.model.condition_encoder`` (mutating the config in place) before
        building the encoder. Then picks the acoustic-mapper class from
        ``cfg.model.transformer.type``.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.

        Raises:
            NotImplementedError: if ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        # Propagate the preprocessing F0 range into the condition-encoder config.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)

        mapper_classes = {"transformer": Transformer, "conformer": Conformer}
        mapper_type = self.cfg.model.transformer.type
        if mapper_type not in mapper_classes:
            raise NotImplementedError(
                f"Unsupported acoustic mapper type {mapper_type!r}; "
                f"expected one of {sorted(mapper_classes)}"
            )
        self.acoustic_mapper = mapper_classes[mapper_type](self.cfg.model.transformer)

        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _forward_step(self, batch):
        """Run one forward pass and return the training loss.

        Args:
            batch: Subscriptable batch. This method reads the target under
                ``"mel"`` and the weighting mask under ``"mask"``. The whole
                batch is also handed to the condition encoder.

        Returns:
            torch.Tensor: Masked-mean L1 loss plus masked-mean SSIM loss
            between the predicted and target mel. ``self._check_nan``
            (inherited) is called on each term before it is summed.
        """
        mel = batch["mel"]
        # NOTE(review): mask is assumed to broadcast against mel (e.g. (B, T, 1)) — confirm.
        mask = batch["mask"]

        condition = self.condition_encoder(batch)
        mel_pred = self.acoustic_mapper(condition, mask)

        # Entries where mask is 0 contribute nothing; both terms are divided by the mask sum.
        mask_sum = torch.sum(mask)

        l1_loss = torch.sum(torch.abs(mel_pred - mel) * mask) / mask_sum
        self._check_nan(l1_loss, mel_pred, mel)

        # NOTE(review): if SSIM() returns a scalar, this masked average reduces to that scalar.
        ssim_loss = self.ssim_loss(mel_pred, mel)
        ssim_loss = torch.sum(ssim_loss * mask) / mask_sum
        self._check_nan(ssim_loss, mel_pred, mel)

        return l1_loss + ssim_loss