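"""Adversarial training task for FastSpeech 2.

Wraps ``FastSpeech2Task`` with a multi-window mel-spectrogram discriminator and
LSGAN-style (least-squares) adversarial losses, trained with two alternating
optimizers: one for the generator, one for the discriminator.
"""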
import torch
import torch.optim
import torch.utils.data
from torch import nn

import utils
from modules.syntaspeech.multi_window_disc import Discriminator
from tasks.tts.fs2 import FastSpeech2Task
from utils.hparams import hparams


class FastSpeech2AdvTask(FastSpeech2Task):
    def build_model(self):
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            self.load_ckpt(hparams['load_ckpt'], strict=False)
        utils.print_arch(self.model, 'Generator')
        self.build_disc_model()
        if not hasattr(self, 'gen_params'):
            # Keep gen_params if build_tts_model already defined it; otherwise optimize all generator weights.
            self.gen_params = list(self.model.parameters())
        return self.model

    def build_disc_model(self):
        disc_win_num = hparams['disc_win_num']
        h = hparams['mel_disc_hidden_size']
        # Multi-window mel discriminator over clip lengths 32/64/128 frames
        # (the first `disc_win_num` of them), with 80 mel bins per frame.
        self.mel_disc = Discriminator(
            time_lengths=[32, 64, 128][:disc_win_num],
            freq_length=80, hidden_size=h, kernel=(3, 3)
        )
        self.disc_params = list(self.mel_disc.parameters())
        utils.print_arch(self.mel_disc, model_name='Mel Disc')
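
    # Loss sketch for the training step below (LSGAN-style, i.e. MSE towards 1/0 targets):
    #   generator:      lambda_mel_adv * || D(mel_pred) - 1 ||^2                 -> keys 'a' / 'ac'
    #   discriminator:  || D(mel_gt) - 1 ||^2  +  || D(mel_pred) - 0 ||^2        -> keys 'r'/'f' and 'rc'/'fc'
    # The discriminator returns an unconditional score 'y' and, when `use_cond_disc`
    # is enabled, a score 'y_c' conditioned on the (detached) decoder input.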

    def _training_step(self, sample, batch_idx, optimizer_idx):
        log_outputs = {}
        loss_weights = {}
        disc_start = hparams['mel_gan'] and self.global_step >= hparams["disc_start_steps"] and \
                     hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            log_outputs, model_out = self.run_model(self.model, sample, return_output=True)
            self.model_out = {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                self.disc_cond = disc_cond = self.model_out['decoder_inp'].detach() \
                    if hparams['use_cond_disc'] else None
                if hparams['mel_loss_no_noise']:
                    self.add_mel_loss(model_out['mel_out_nonoise'], sample['mels'], log_outputs)
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p, disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                if p_ is not None:
                    # LSGAN generator loss: push D(mel_pred) towards 1.
                    log_outputs['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    log_outputs['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
        else:
            #######################
            #    Discriminator    #
            #######################
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                if hparams['rerun_gen']:
                    # Regenerate mels without gradients instead of reusing the cached generator output.
                    with torch.no_grad():
                        _, model_out = self.run_model(self.model, sample, return_output=True)
                else:
                    model_out = self.model_out
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o = self.mel_disc(mel_g, self.disc_cond)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p, self.disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                if p_ is not None:
                    # LSGAN discriminator loss: real mels towards 1, predicted mels towards 0.
                    log_outputs["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    log_outputs["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
                if pc_ is not None:
                    log_outputs["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    log_outputs["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))
        if len(log_outputs) == 0:
            return None
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in log_outputs.items()])
        log_outputs['bs'] = sample['mels'].shape[0]
        return total_loss, log_outputs
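
    # Two optimizers are returned below, so the trainer alternates optimizer_idx between
    # 0 (generator) and 1 (discriminator). Adversarial terms only become active once
    # `mel_gan` is set, `lambda_mel_adv` > 0 and global_step >= `disc_start_steps`;
    # the discriminator is then updated every `disc_interval` steps.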

    def configure_optimizers(self):
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        optimizer_gen = torch.optim.AdamW(
            self.gen_params,
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        optimizer_disc = torch.optim.AdamW(
            self.disc_params,
            lr=hparams['disc_lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            **hparams["discriminator_optimizer_params"]) if len(self.disc_params) > 0 else None
        self.scheduler = self.build_scheduler({'gen': optimizer_gen, 'disc': optimizer_disc})
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        return {
            "gen": super().build_scheduler(optimizer['gen']),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer["disc"],
                **hparams["discriminator_scheduler_params"]) if optimizer["disc"] is not None else None,
        }

    def on_before_optimization(self, opt_idx):
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.gen_params, hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step)
        else:
            # The discriminator scheduler counts steps from disc_start_steps onwards.
            self.scheduler['disc'].step(max(self.global_step - hparams["disc_start_steps"], 1))
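

# hparams consumed by this task (keys only; values come from the experiment config):
#   GAN schedule:   mel_gan, disc_start_steps, disc_interval, lambda_mel_adv
#   discriminator:  disc_win_num, mel_disc_hidden_size, use_cond_disc, rerun_gen,
#                   disc_lr, discriminator_optimizer_params,
#                   discriminator_scheduler_params, discriminator_grad_norm
#   generator:      lr, optimizer_adam_beta1, optimizer_adam_beta2, weight_decay,
#                   generator_grad_norm, mel_loss_no_noise, load_ckpt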