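# Leftover file-view header from the page this file was copied from (kept as a comment so the module parses):
# lmzjms's picture | "Upload 591 files" | commit 9206300 | raw | history blame | 5.81 kB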
from tasks.tts.fs2 import FastSpeech2Task
from modules.syntaspeech.multi_window_disc import Discriminator
from utils.hparams import hparams
from torch import nn
import torch
import torch.optim
import torch.utils.data
import utils
class FastSpeech2AdvTask(FastSpeech2Task):
    """FastSpeech 2 training task with an adversarial mel-spectrogram discriminator.

    Two optimizers are used:
      * index 0 updates the generator (``self.model``, built by ``build_tts_model``);
      * index 1 updates the multi-window mel discriminator (``self.mel_disc``).

    Adversarial losses are least-squares style: MSE against all-ones targets for
    "real" and all-zeros targets for "fake". They are only active once
    ``hparams['mel_gan']`` is set, ``global_step >= hparams['disc_start_steps']``,
    and ``hparams['lambda_mel_adv'] > 0``.
    """

    def build_model(self):
        """Build the generator and the discriminator, and return the generator.

        If ``hparams['load_ckpt']`` is non-empty, the generator is warm-started
        from it non-strictly. The generator parameter list is cached in
        ``self.gen_params`` (if not already set) for its optimizer and clipping.
        """
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            self.load_ckpt(hparams['load_ckpt'], strict=False)
        utils.print_arch(self.model, 'Generator')
        self.build_disc_model()
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        return self.model

    def build_disc_model(self):
        """Create ``self.mel_disc`` and cache its parameters in ``self.disc_params``.

        ``hparams['disc_win_num']`` selects how many of the window lengths
        (32, 64, 128) along the time axis are used. The frequency size follows
        ``hparams['audio_num_mel_bins']``. It falls back to 80, the value that was
        previously hard-coded, when that key is absent.
        """
        disc_win_num = hparams['disc_win_num']
        h = hparams['mel_disc_hidden_size']
        self.mel_disc = Discriminator(
            time_lengths=[32, 64, 128][:disc_win_num],
            # Must match the mel dimension of sample['mels'] fed to the discriminator.
            freq_length=hparams.get('audio_num_mel_bins', 80),
            hidden_size=h,
            kernel=(3, 3),
        )
        self.disc_params = list(self.mel_disc.parameters())
        utils.print_arch(self.mel_disc, model_name='Mel Disc')

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the losses for one optimizer on one batch.

        Args:
            sample: batch dict; ``sample['mels']`` is used as the real mel input.
            batch_idx: index of the batch (unused here).
            optimizer_idx: 0 for the generator step, anything else for the
                discriminator step.

        Returns:
            ``None`` when there is no loss to optimize. Otherwise a tuple
            ``(total_loss, log_outputs)``:
              * ``total_loss`` is the weighted sum of ``log_outputs``. The
                adversarial terms ``'a'``/``'ac'`` are weighted by
                ``hparams['lambda_mel_adv']``; every other term has weight 1.
              * ``log_outputs['bs']`` holds the batch size.

        NOTE(review): the discriminator step reads ``self.model_out`` and
        ``self.disc_cond``, which are set by the generator step. This assumes the
        generator step runs before the discriminator step for the same
        ``global_step``. Confirm this against the trainer.
        """
        log_outputs = {}
        loss_weights = {}
        disc_start = hparams['mel_gan'] and self.global_step >= hparams["disc_start_steps"] and \
            hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            log_outputs, model_out = self.run_model(self.model, sample, return_output=True)
            # Detached copy of the tensor outputs. The discriminator step reuses it
            # when hparams['rerun_gen'] is off.
            self.model_out = {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                # self.model_out values are already detached.
                self.disc_cond = disc_cond = self.model_out['decoder_inp'] \
                    if hparams['use_cond_disc'] else None
                if hparams['mel_loss_no_noise']:
                    self.add_mel_loss(model_out['mel_out_nonoise'], sample['mels'], log_outputs)
                # Use the non-detached output so the adversarial loss reaches the generator.
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p, disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                # The generator wants its fakes scored as real (target 1).
                if p_ is not None:
                    log_outputs['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    log_outputs['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
        else:
            #######################
            #    Discriminator    #
            #######################
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                if hparams['rerun_gen']:
                    with torch.no_grad():
                        _, model_out = self.run_model(self.model, sample, return_output=True)
                else:
                    model_out = self.model_out
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                # out2mel runs outside no_grad. Detach so the discriminator loss
                # cannot backpropagate into generator parameters.
                mel_p = mel_p.detach()
                o = self.mel_disc(mel_g, self.disc_cond)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p, self.disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                # Real mels -> target 1, generated mels -> target 0.
                if p_ is not None:
                    log_outputs["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    log_outputs["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
                if pc_ is not None:
                    log_outputs["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    log_outputs["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))
        # Nothing to optimize on this step (e.g. discriminator not started or
        # skipped by disc_interval).
        if not log_outputs:
            return None
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in log_outputs.items())
        log_outputs['bs'] = sample['mels'].shape[0]
        return total_loss, log_outputs

    def configure_optimizers(self):
        """Create the generator and discriminator AdamW optimizers.

        The generator uses ``hparams['lr']`` and ``hparams['weight_decay']``.
        The discriminator uses ``hparams['disc_lr']`` plus
        ``hparams['discriminator_optimizer_params']``, and is ``None`` when the
        discriminator has no parameters. Also builds ``self.scheduler``.

        Returns:
            ``[optimizer_gen, optimizer_disc]``, where ``optimizer_disc`` may be ``None``.
        """
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        optimizer_gen = torch.optim.AdamW(
            self.gen_params,
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        optimizer_disc = torch.optim.AdamW(
            self.disc_params,
            lr=hparams['disc_lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            **hparams["discriminator_optimizer_params"]) if len(self.disc_params) > 0 else None
        self.scheduler = self.build_scheduler({'gen': optimizer_gen, 'disc': optimizer_disc})
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Build LR schedulers for ``optimizer = {'gen': ..., 'disc': ...}``.

        The generator uses the parent class's scheduler. The discriminator uses
        ``StepLR`` configured by ``hparams['discriminator_scheduler_params']``, or
        ``None`` when there is no discriminator optimizer.
        """
        return {
            "gen": super().build_scheduler(optimizer['gen']),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer["disc"],
                **hparams["discriminator_scheduler_params"]) if optimizer["disc"] is not None else None,
        }

    def on_before_optimization(self, opt_idx):
        """Clip the gradient norm of the parameter group about to be stepped."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.gen_params, hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the LR scheduler that belongs to the optimizer that just stepped.

        The discriminator scheduler is driven by the number of steps since
        ``hparams['disc_start_steps']`` (at least 1). It is skipped when no
        discriminator scheduler exists, which ``build_scheduler`` allows.
        """
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step)
        elif self.scheduler['disc'] is not None:
            self.scheduler['disc'].step(max(self.global_step - hparams["disc_start_steps"], 1))