|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from models.svc.base import SVCTrainer |
|
from modules.encoder.condition_encoder import ConditionEncoder |
|
from models.svc.transformer.transformer import Transformer |
|
from models.svc.transformer.conformer import Conformer |
|
from utils.ssim import SSIM |
|
|
|
|
|
class TransformerTrainer(SVCTrainer):
    """SVC trainer whose model is a condition encoder plus an acoustic mapper.

    The acoustic mapper is a ``Transformer`` or a ``Conformer``, selected by
    ``cfg.model.transformer.type``. The training loss is the sum of a masked
    L1 term and a masked SSIM term between the predicted and target mel.
    """

    def __init__(self, args, cfg):
        """Initialise the base trainer, then create the SSIM loss module.

        Args:
            args: Forwarded unchanged to ``SVCTrainer.__init__``.
            cfg: Forwarded unchanged to ``SVCTrainer.__init__``.
        """
        super().__init__(args, cfg)
        self.ssim_loss = SSIM()

    def _build_model(self):
        """Construct the condition encoder and the acoustic mapper.

        The F0 range from ``cfg.preprocess`` is copied into
        ``cfg.model.condition_encoder`` (``self.cfg`` is mutated in place)
        before the encoder is built from that config node.

        Returns:
            A ``torch.nn.ModuleList`` holding
            ``[condition_encoder, acoustic_mapper]`` in that order.

        Raises:
            NotImplementedError: If ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        encoder_cfg = self.cfg.model.condition_encoder
        encoder_cfg.f0_min = self.cfg.preprocess.f0_min
        encoder_cfg.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(encoder_cfg)

        mapper_cfg = self.cfg.model.transformer
        mapper_type = mapper_cfg.type
        if mapper_type == "transformer":
            self.acoustic_mapper = Transformer(mapper_cfg)
        elif mapper_type == "conformer":
            self.acoustic_mapper = Conformer(mapper_cfg)
        else:
            # Name the offending value so a config typo is easy to diagnose.
            raise NotImplementedError(
                f"Unsupported cfg.model.transformer.type: {mapper_type!r} "
                "(expected 'transformer' or 'conformer')"
            )

        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _forward_step(self, batch):
        """Run one forward pass and return the combined training loss.

        Args:
            batch: Mapping that provides at least ``"mel"`` (the target) and
                ``"mask"``. The whole mapping is also passed to the condition
                encoder, which may read further keys.

        Returns:
            The masked L1 loss plus the masked SSIM loss.
        """
        mel = batch["mel"]
        mask = batch["mask"]

        condition = self.condition_encoder(batch)
        mel_pred = self.acoustic_mapper(condition, mask)

        # Both terms are normalised by the same mask sum, so compute it once.
        # NOTE(review): presumably the mask broadcasts against the mel tensors,
        # which makes this a per-masked-position average - confirm shapes.
        mask_sum = torch.sum(mask)

        l1_loss = torch.sum(torch.abs(mel_pred - mel) * mask) / mask_sum
        # _check_nan is not defined in this class body; it is called after
        # each loss term with the loss and the tensors it was computed from.
        self._check_nan(l1_loss, mel_pred, mel)

        ssim_loss = self.ssim_loss(mel_pred, mel)
        ssim_loss = torch.sum(ssim_loss * mask) / mask_sum
        self._check_nan(ssim_loss, mel_pred, mel)

        return l1_loss + ssim_loss
|
|