maskgct / models /svc /comosvc /comosvc_trainer.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
12.2 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import os
import json5
from tqdm import tqdm
import json
import shutil
from models.svc.base import SVCTrainer
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.comosvc.comosvc import ComoSVC
class ComoSVCTrainer(SVCTrainer):
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
implements ``_build_model`` and ``_forward_step`` methods.
"""
def __init__(self, args=None, cfg=None):
SVCTrainer.__init__(self, args, cfg)
self.distill = cfg.model.comosvc.distill
self.skip_diff = True
### Following are methods only for comoSVC models ###
def _load_teacher_model(self, model):
r"""Load teacher model from checkpoint file."""
self.checkpoint_file = self.teacher_model_path
self.logger.info(
"Load teacher acoustic model from {}".format(self.checkpoint_file)
)
raw_dict = torch.load(self.checkpoint_file)
model.load_state_dict(raw_dict)
def _build_model(self):
r"""Build the model for training. This function is called in ``__init__`` function."""
# TODO: sort out the config
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
self.acoustic_mapper = ComoSVC(self.cfg)
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
if self.cfg.model.comosvc.distill:
if not self.args.resume:
# do not load teacher model when resume
self.teacher_model_path = self.cfg.model.teacher_model_path
self._load_teacher_model(model)
# build teacher & target decoder and freeze teacher
self.acoustic_mapper.decoder.init_consistency_training()
self.freeze_net(self.condition_encoder)
self.freeze_net(self.acoustic_mapper.encoder)
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_pretrained)
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_ema)
return model
def freeze_net(self, model):
r"""Freeze the model for training."""
for name, param in model.named_parameters():
param.requires_grad = False
def __build_optimizer(self):
r"""Build optimizer for training. This function is called in ``__init__`` function."""
if self.cfg.train.optimizer.lower() == "adamw":
optimizer = torch.optim.AdamW(
params=filter(lambda p: p.requires_grad, self.model.parameters()),
**self.cfg.train.adamw,
)
else:
raise NotImplementedError(
"Not support optimizer: {}".format(self.cfg.train.optimizer)
)
return optimizer
def _forward_step(self, batch):
r"""Forward step for training and inference. This function is called
in ``_train_step`` & ``_test_step`` function.
"""
loss = {}
mask = batch["mask"]
mel_input = batch["mel"]
cond = self.condition_encoder(batch)
if self.distill:
cond = cond.detach()
self.skip_diff = True if self.step < self.cfg.train.fast_steps else False
ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
mask, cond, mel_input, skip_diff=self.skip_diff
)
if self.distill:
loss["distil_loss"] = diff_loss
else:
loss["ssim_loss_encoder"] = ssim_loss
loss["prior_loss_encoder"] = prior_loss
loss["diffusion_loss_decoder"] = diff_loss
return loss
def _train_epoch(self):
r"""Training epoch. Should return average loss of a batch (sample) over
one epoch. See ``train_loop`` for usage.
"""
self.model.train()
epoch_sum_loss: float = 0.0
epoch_step: int = 0
for batch in tqdm(
self.train_dataloader,
desc=f"Training Epoch {self.epoch}",
unit="batch",
colour="GREEN",
leave=False,
dynamic_ncols=True,
smoothing=0.04,
disable=not self.accelerator.is_main_process,
):
# Do training step and BP
with self.accelerator.accumulate(self.model):
loss = self._train_step(batch)
total_loss = 0
for k, v in loss.items():
total_loss += v
self.accelerator.backward(total_loss)
enc_grad_norm = torch.nn.utils.clip_grad_norm_(
self.acoustic_mapper.encoder.parameters(), max_norm=1
)
dec_grad_norm = torch.nn.utils.clip_grad_norm_(
self.acoustic_mapper.decoder.parameters(), max_norm=1
)
self.optimizer.step()
self.optimizer.zero_grad()
self.batch_count += 1
# Update info for each step
# TODO: step means BP counts or batch counts?
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
epoch_sum_loss += total_loss
log_info = {}
for k, v in loss.items():
key = "Step/Train Loss/{}".format(k)
log_info[key] = v
log_info["Step/Learning Rate"] = self.optimizer.param_groups[0]["lr"]
self.accelerator.log(
log_info,
step=self.step,
)
self.step += 1
epoch_step += 1
self.accelerator.wait_for_everyone()
return (
epoch_sum_loss
/ len(self.train_dataloader)
* self.cfg.train.gradient_accumulation_step,
loss,
)
def train_loop(self):
r"""Training loop. The public entry of training process."""
# Wait everyone to prepare before we move on
self.accelerator.wait_for_everyone()
# dump config file
if self.accelerator.is_main_process:
self.__dump_cfg(self.config_save_path)
self.model.train()
self.optimizer.zero_grad()
# Wait to ensure good to go
self.accelerator.wait_for_everyone()
while self.epoch < self.max_epoch:
self.logger.info("\n")
self.logger.info("-" * 32)
self.logger.info("Epoch {}: ".format(self.epoch))
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
### It's inconvenient for the model with multiple losses
# Do training & validating epoch
train_loss, loss = self._train_epoch()
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
for k, v in loss.items():
self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
valid_loss = self._valid_epoch()
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
self.accelerator.log(
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
step=self.epoch,
)
self.accelerator.wait_for_everyone()
# TODO: what is scheduler?
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
# Check if hit save_checkpoint_stride and run_eval
run_eval = False
if self.accelerator.is_main_process:
save_checkpoint = False
hit_dix = []
for i, num in enumerate(self.save_checkpoint_stride):
if self.epoch % num == 0:
save_checkpoint = True
hit_dix.append(i)
run_eval |= self.run_eval[i]
self.accelerator.wait_for_everyone()
if (
self.accelerator.is_main_process
and save_checkpoint
and (self.distill or not self.skip_diff)
):
path = os.path.join(
self.checkpoint_dir,
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
self.epoch, self.step, train_loss
),
)
self.tmp_checkpoint_save_path = path
self.accelerator.save_state(path)
print(f"save checkpoint in {path}")
json.dump(
self.checkpoints_path,
open(os.path.join(path, "ckpts.json"), "w"),
ensure_ascii=False,
indent=4,
)
self._save_auxiliary_states()
# Remove old checkpoints
to_remove = []
for idx in hit_dix:
self.checkpoints_path[idx].append(path)
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
# Search conflicts
total = set()
for i in self.checkpoints_path:
total |= set(i)
do_remove = set()
for idx, path in to_remove[::-1]:
if path in total:
self.checkpoints_path[idx].insert(0, path)
else:
do_remove.add(path)
# Remove old checkpoints
for path in do_remove:
shutil.rmtree(path, ignore_errors=True)
self.logger.debug(f"Remove old checkpoint: {path}")
self.accelerator.wait_for_everyone()
if run_eval:
# TODO: run evaluation
pass
# Update info for each epoch
self.epoch += 1
# Finish training and save final checkpoint
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
self.accelerator.save_state(
os.path.join(
self.checkpoint_dir,
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
self.epoch, self.step, valid_loss
),
)
)
self._save_auxiliary_states()
self.accelerator.end_training()
@torch.inference_mode()
def _valid_epoch(self):
r"""Testing epoch. Should return average loss of a batch (sample) over
one epoch. See ``train_loop`` for usage.
"""
self.model.eval()
epoch_sum_loss = 0.0
for batch in tqdm(
self.valid_dataloader,
desc=f"Validating Epoch {self.epoch}",
unit="batch",
colour="GREEN",
leave=False,
dynamic_ncols=True,
smoothing=0.04,
disable=not self.accelerator.is_main_process,
):
batch_loss = self._valid_step(batch)
for k, v in batch_loss.items():
epoch_sum_loss += v
self.accelerator.wait_for_everyone()
return epoch_sum_loss / len(self.valid_dataloader)
@staticmethod
def __count_parameters(model):
model_param = 0.0
if isinstance(model, dict):
for key, value in model.items():
model_param += sum(p.numel() for p in model[key].parameters())
else:
model_param = sum(p.numel() for p in model.parameters())
return model_param
def __dump_cfg(self, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
json5.dump(
self.cfg,
open(path, "w"),
indent=4,
sort_keys=True,
ensure_ascii=False,
quote_keys=True,
)