# Source: maskgct/models/svc/vits/vits_trainer.py
# Uploaded by Hecheng0625 in "Upload 409 files" (commit c968fc3, verified); 25.6 kB.
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
from pathlib import Path
import shutil
import accelerate
# from models.svc.base import SVCTrainer
from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset
from models.svc.vits.vits import *
from models.svc.base import SVCTrainer
from utils.mel import mel_spectrogram_torch
import json
from models.vocoders.gan.discriminator.mpd import (
MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
)
class VitsSVCTrainer(SVCTrainer):
    r"""Trainer for the VITS-based singing voice conversion model.

    A generator (``SynthesizerTrn``) is trained adversarially against a
    multi-period discriminator, each with its own AdamW optimizer and
    ``ExponentialLR`` scheduler. ``self.model``, ``self.optimizer`` and
    ``self.scheduler`` are dicts keyed by ``"generator"``/``"discriminator"``,
    ``"optimizer_g"``/``"optimizer_d"`` and ``"scheduler_g"``/``"scheduler_d"``.
    """

    def __init__(self, args, cfg):
        self.args = args
        self.cfg = cfg
        SVCTrainer.__init__(self, args, cfg)

    def _accelerator_prepare(self):
        r"""Wrap dataloaders, models, optimizers and schedulers with accelerate.

        Dict-valued ``model``/``optimizer``/``scheduler`` attributes are
        prepared entry by entry, in dict order, so their structure is kept.
        """
        (
            self.train_dataloader,
            self.valid_dataloader,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
        )
        self.model = self._prepare_maybe_dict(self.model)
        self.optimizer = self._prepare_maybe_dict(self.optimizer)
        self.scheduler = self._prepare_maybe_dict(self.scheduler)

    def _prepare_maybe_dict(self, obj):
        r"""Prepare ``obj`` with accelerate, or each value in place if it is a dict."""
        if not isinstance(obj, dict):
            return self.accelerator.prepare(obj)
        for key in obj.keys():
            obj[key] = self.accelerator.prepare(obj[key])
        return obj

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        resume_type: str = "",
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.

        Args:
            checkpoint_dir: directory whose entries are named like
                ``[final_]epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX``. The entry
                with the highest epoch is used when ``checkpoint_path`` is None.
            checkpoint_path: explicit checkpoint directory to load.
            resume_type: ``"resume"`` or ``""`` restores the full accelerate
                state and sets ``self.epoch``/``self.step`` to the values after
                the checkpoint. ``"finetune"`` loads only the generator and
                discriminator weights.

        Returns:
            The checkpoint path that was loaded.

        Raises:
            ValueError: if ``resume_type`` is not recognized.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        self.logger.info("Resume from {}...".format(checkpoint_path))
        if resume_type in ["resume", ""]:
            # Load all the things, including model weights, optimizer, scheduler, and random states.
            self.accelerator.load_state(input_dir=checkpoint_path)
            # Continue from the epoch/step after the one encoded in the name.
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        elif resume_type == "finetune":
            # Load only the model weights. accelerate's save_state writes the
            # first prepared model to pytorch_model.bin and the i-th one to
            # pytorch_model_<i>.bin; _accelerator_prepare prepares the
            # generator before the discriminator.
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["generator"]),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["discriminator"]),
                os.path.join(checkpoint_path, "pytorch_model_1.bin"),
            )
            self.logger.info("Load model weights for finetune...")
        else:
            raise ValueError("Resume_type must be `resume` or `finetune`.")
        return checkpoint_path

    def _build_model(self):
        r"""Build the generator and the multi-period discriminator.

        The generator receives ``n_fft // 2 + 1`` (spectrogram bins),
        ``segment_size // hop_size`` (training segment length in frames) and
        the full config.
        """
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            # directly use cfg
            self.cfg,
        )
        net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
        return {"generator": net_g, "discriminator": net_d}

    def _build_dataset(self):
        r"""Return the (dataset class, collator class) pair used for SVC."""
        return SVCOfflineDataset, SVCOfflineCollator

    def _build_optimizer(self):
        r"""Build one AdamW optimizer per network, sharing the same hyper-parameters."""
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        return {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}

    def _build_scheduler(self):
        r"""Build exponential LR decay schedulers positioned at ``self.epoch - 1``."""
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        return {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}

    def _build_criterion(self):
        r"""Build the generator and discriminator loss modules."""

        class GeneratorLoss(nn.Module):
            r"""Generator objective: mel L1 + KL + feature matching + LSGAN term."""

            def __init__(self, cfg):
                super(GeneratorLoss, self).__init__()
                self.cfg = cfg
                self.l1_loss = nn.L1Loss()

            def generator_loss(self, disc_outputs):
                r"""Least-squares GAN loss ``mean((1 - D(G(x)))**2)`` summed over discriminators.

                Returns the total and the per-discriminator terms.
                """
                loss = 0
                gen_losses = []
                for dg in disc_outputs:
                    dg = dg.float()
                    l = torch.mean((1 - dg) ** 2)
                    gen_losses.append(l)
                    loss += l
                return loss, gen_losses

            def feature_loss(self, fmap_r, fmap_g):
                r"""Mean L1 between real (detached) and generated feature maps, times 2."""
                loss = 0
                for dr, dg in zip(fmap_r, fmap_g):
                    for rl, gl in zip(dr, dg):
                        rl = rl.float().detach()
                        gl = gl.float()
                        loss += torch.mean(torch.abs(rl - gl))
                return loss * 2

            def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
                """
                z_p, logs_q: [b, h, t_t]
                m_p, logs_p: [b, h, t_t]

                Masked KL term, summed over masked positions and divided by
                the mask sum.
                """
                z_p = z_p.float()
                logs_q = logs_q.float()
                m_p = m_p.float()
                logs_p = logs_p.float()
                z_mask = z_mask.float()
                kl = logs_p - logs_q - 0.5
                kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
                kl = torch.sum(kl * z_mask)
                l = kl / torch.sum(z_mask)
                return l

            def forward(
                self,
                outputs_g,
                outputs_d,
                y_mel,
                y_hat_mel,
            ):
                r"""Return a dict with ``loss_mel``, ``loss_kl``, ``loss_fm``,
                ``loss_gen`` and their sum ``loss_gen_all``."""
                loss_g = {}
                # mel loss
                loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
                loss_g["loss_mel"] = loss_mel
                # kl loss
                loss_kl = (
                    self.kl_loss(
                        outputs_g["z_p"],
                        outputs_g["logs_q"],
                        outputs_g["m_p"],
                        outputs_g["logs_p"],
                        outputs_g["z_mask"],
                    )
                    * self.cfg.train.c_kl
                )
                loss_g["loss_kl"] = loss_kl
                # feature loss
                loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
                loss_g["loss_fm"] = loss_fm
                # gan loss
                loss_gen, _ = self.generator_loss(outputs_d["y_d_hat_g"])
                loss_g["loss_gen"] = loss_gen
                loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen
                return loss_g

        class DiscriminatorLoss(nn.Module):
            r"""Least-squares discriminator loss summed over sub-discriminators."""

            def __init__(self, cfg):
                super(DiscriminatorLoss, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")

            def forward(self, disc_real_outputs, disc_generated_outputs):
                r"""Return ``{"loss_disc_all": sum(mean((1 - D(y))**2) + mean(D(y_hat)**2))}``."""
                loss = 0
                for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = torch.mean((1 - dr) ** 2)
                    g_loss = torch.mean(dg**2)
                    # Per-term ``.item()`` bookkeeping was dropped: it was
                    # never returned and forced a device sync per term.
                    loss += r_loss + g_loss
                return {"loss_disc_all": loss}

        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

    # Legacy TensorBoard helpers writing to ``self.sw``; the training loop in
    # this class logs through ``accelerator.log`` instead.
    def write_summary(
        self,
        losses,
        stats,
        images=None,
        audios=None,
        audio_sampling_rate=24000,
        tag="train",
    ):
        r"""Write losses, the generator learning rate and optional media to ``self.sw``."""
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )
        self._write_media_summary(images, audios, audio_sampling_rate)

    def write_valid_summary(
        self, losses, stats, images=None, audios=None, audio_sampling_rate=24000, tag="val"
    ):
        r"""Write validation losses and optional media to ``self.sw``."""
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self._write_media_summary(images, audios, audio_sampling_rate)

    def _write_media_summary(self, images, audios, audio_sampling_rate):
        r"""Log optional HWC images and audio clips to ``self.sw`` at ``self.step``."""
        for key, value in (images or {}).items():
            self.sw.add_image(key, value, self.step, dataformats="HWC")
        for key, value in (audios or {}).items():
            self.sw.add_audio(key, value, self.step, audio_sampling_rate)

    def _get_state_dict(self):
        r"""Alias of :meth:`get_state_dict`."""
        return self.get_state_dict()

    def get_state_dict(self):
        r"""Collect model/optimizer/scheduler states plus step, epoch and batch size."""
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def load_model(self, checkpoint):
        r"""Restore everything written by :meth:`get_state_dict` (except batch size)."""
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.model["generator"].load_state_dict(checkpoint["generator"])
        self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
        self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
        self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
        self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
        self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])

    def _set_models_train_mode(self, mode=True):
        r"""Call ``.train(mode)`` on the model, or on every model in the dict."""
        models = self.model.values() if isinstance(self.model, dict) else [self.model]
        for model in models:
            model.train(mode)

    def _slice_targets(self, batch, outputs_g):
        r"""Return ``(y_mel, y_hat_mel, y)`` aligned with the generator's random slice.

        ``y_mel`` is the ground-truth mel cut at ``ids_slice`` (in frames),
        ``y_hat_mel`` the mel of the generated audio, and ``y`` the
        ground-truth waveform cut at ``ids_slice * hop_size`` (in samples).
        """
        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        y = slice_segments(
            # [1, 168418] -> [1, 1, 168418]
            batch["audio"].unsqueeze(1),
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )
        return y_mel, y_hat_mel, y

    @torch.inference_mode()
    def _valid_step(self, batch):
        r"""Compute discriminator and generator losses on one batch without updates.

        Returns:
            ``(total_loss, losses, stats)``: the summed generator+discriminator
            loss as a float, a dict of per-term float losses, and an empty
            stats dict.
        """
        valid_losses = {}
        valid_stats = {}
        # Generator output
        outputs_g = self.model["generator"](batch)
        y_mel, y_hat_mel, y = self._slice_targets(batch, outputs_g)
        # Under inference_mode there is no graph, so the detached and
        # non-detached discriminator passes are identical; run it once.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        ## Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        valid_losses.update(loss_d)
        ## Generator loss
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        valid_losses.update(loss_g)
        valid_losses = {key: value.item() for key, value in valid_losses.items()}
        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Puts the models in eval mode; ``_train_epoch`` switches them back.
        """
        self._set_models_train_mode(False)
        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    epoch_losses[key] = epoch_losses.get(key, 0) + value
        # Guard against an empty validation set.
        num_batches = max(len(self.valid_dataloader), 1)
        epoch_sum_loss = epoch_sum_loss / num_batches
        for key in epoch_losses.keys():
            epoch_losses[key] = epoch_losses[key] / num_batches
        self.accelerator.wait_for_everyone()
        return epoch_sum_loss, epoch_losses

    def _save_checkpoint_to(self, path):
        r"""Save accelerate state, the checkpoint index and auxiliary states to ``path``."""
        self.tmp_checkpoint_save_path = path
        self.accelerator.save_state(path)
        with open(os.path.join(path, "ckpts.json"), "w", encoding="utf-8") as f:
            json.dump(
                self.checkpoints_path,
                f,
                ensure_ascii=False,
                indent=4,
            )
        self._save_auxiliary_states()

    def _prune_old_checkpoints(self, path, hit_dix):
        r"""Register ``path`` under each hit stride and delete checkpoints that
        fell out of every stride's ``keep_last`` window."""
        to_remove = []
        for idx in hit_dix:
            self.checkpoints_path[idx].append(path)
            while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
        # Search conflicts: a checkpoint still tracked by another stride is
        # put back instead of being deleted.
        still_kept = set()
        for kept_paths in self.checkpoints_path:
            still_kept |= set(kept_paths)
        do_remove = set()
        for idx, old_path in to_remove[::-1]:
            if old_path in still_kept:
                self.checkpoints_path[idx].insert(0, old_path)
            else:
                do_remove.add(old_path)
        # Remove old checkpoints
        for old_path in do_remove:
            shutil.rmtree(old_path, ignore_errors=True)
            self.logger.debug(f"Remove old checkpoint: {old_path}")

    ### THIS IS MAIN ENTRY ###
    def train_loop(self):
        r"""Training loop. The public entry of training process.

        Alternates training and validation epochs until ``self.max_epoch``.
        Logs per-epoch losses, saves and prunes checkpoints on the configured
        strides (main process only), and finally saves a ``final_...``
        checkpoint.
        """
        # Wait everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)
        # Wait to ensure good to go
        self.accelerator.wait_for_everyone()
        # Used in the final checkpoint name; defined even if no epoch runs
        # (e.g. when resuming an already finished run).
        valid_total_loss = 0.0
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Do training & validating epoch
            train_total_loss, train_losses = self._train_epoch()
            if isinstance(train_losses, dict):
                for key, loss in train_losses.items():
                    self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            valid_total_loss, valid_losses = self._valid_epoch()
            if isinstance(valid_losses, dict):
                for key, loss in valid_losses.items():
                    self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
                    # Separate key so validation values do not overwrite the
                    # training metric of the same name.
                    self.accelerator.log(
                        {"Epoch/Valid {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
            self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            save_checkpoint = False
            hit_dix = []
            if self.accelerator.is_main_process:
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_dix.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                self._save_checkpoint_to(path)
                self._prune_old_checkpoints(path, hit_dix)

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, valid_total_loss
                ),
            )
            self._save_checkpoint_to(path)
        self.accelerator.end_training()

    def _train_step(self, batch):
        r"""Run one adversarial training step on ``batch``.

        First updates the discriminator on real vs. detached generated audio,
        then updates the generator using a fresh discriminator pass.

        Returns:
            ``(total_loss, losses, stats)``: the summed generator+discriminator
            loss as a float, a dict of per-term float losses, and an empty
            stats dict.
        """
        train_losses = {}
        training_stats = {}
        ## Train Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)
        y_mel, y_hat_mel, y = self._slice_targets(batch, outputs_g)
        # Discriminator output; y_hat is detached so no gradient reaches G.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        # Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        train_losses.update(loss_d)
        # BP and Grad Updated
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()
        ## Train Generator
        # Re-run D on the non-detached output so the GAN/feature losses
        # backpropagate into the generator.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        train_losses.update(loss_g)
        # BP and Grad Updated
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["loss_gen_all"])
        self.optimizer["optimizer_g"].step()
        for item in train_losses:
            train_losses[item] = train_losses[item].item()
        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Switches the models to train mode first, because ``_valid_epoch``
        leaves them in eval mode.
        """
        self._set_models_train_mode(True)
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1
            # Update info for each step
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    epoch_losses[key] = epoch_losses.get(key, 0) + value
                self.accelerator.log(
                    {
                        "Step/Generator Loss": train_losses["loss_gen_all"],
                        "Step/Discriminator Loss": train_losses["loss_disc_all"],
                        "Step/Generator Learning Rate": self.optimizer[
                            "optimizer_g"
                        ].param_groups[0]["lr"],
                        "Step/Discriminator Learning Rate": self.optimizer[
                            "optimizer_d"
                        ].param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
        self.accelerator.wait_for_everyone()
        # Average over logged (every gradient_accumulation_step-th) steps;
        # guard against an empty dataloader.
        num_batches = max(len(self.train_dataloader), 1)
        epoch_sum_loss = (
            epoch_sum_loss / num_batches * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / num_batches
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_sum_loss, epoch_losses

    def __dump_cfg(self, path):
        r"""Write ``self.cfg`` to ``path`` as JSON5, creating parent directories."""
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json5.dump(
                self.cfg,
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )