import torch
from torch import nn
from tasks.tts.ps_adv import PortaSpeechAdvTask, FastSpeechTask
from text_to_speech.utils.commons.hparams import hparams
from text_to_speech.utils.nn.seq_utils import group_hidden_by_segs


def _contrastive_ce(cos_sim):
    """Return the cross-entropy of a score matrix whose positives sit on the diagonal.

    Args:
        cos_sim: ``[n, m]`` similarity scores with ``m >= n``; column ``i`` is
            the positive for row ``i`` and every other column is a negative.

    Returns:
        Scalar mean cross-entropy loss.
    """
    labels = torch.arange(cos_sim.size(0), dtype=torch.long, device=cos_sim.device)
    return nn.CrossEntropyLoss()(cos_sim, labels)


def _pairwise_sim(sim, a, b):
    """Score every row of ``a`` against every row of ``b`` by broadcasting through ``sim``.

    ``sim`` is the encoder's similarity module; it is assumed to reduce the
    last dimension so the result is ``[len(a), len(b)]`` (TODO confirm).
    """
    return sim(a.unsqueeze(1), b.unsqueeze(0))


def _flatten_paired_views(cl_feats):
    """Flatten the paired-view BERT inputs of ``cl_feats`` from ``[bs, 2, t]`` to ``[bs * 2, t]``.

    Row ``2 * i + v`` of each flattened tensor is view ``v`` of sample ``i``.

    Returns:
        Tuple ``(bs, t, input_ids, attention_mask, token_type_ids)``.
    """
    bs, _, t = cl_feats['cl_input_ids'].shape
    input_ids = cl_feats['cl_input_ids'].reshape([bs * 2, t])
    attention_mask = cl_feats['cl_attention_mask'].reshape([bs * 2, t])
    token_type_ids = cl_feats['cl_token_type_ids'].reshape([bs * 2, t])
    return bs, t, input_ids, attention_mask, token_type_ids


def _sentence_cl_loss(bert, pooler, sim, cl_feats):
    """Sentence-level contrastive loss between the two pooled BERT views of each sample.

    Used by ``cl_version`` ``"v1"`` and ``"v4"``: view 0 of sample ``i`` is the
    positive for view 1 of sample ``i``; other samples in the batch are negatives.
    """
    bs, _, input_ids, attention_mask, token_type_ids = _flatten_paired_views(cl_feats)
    cl_output = bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    pooled = pooler(attention_mask, cl_output).reshape([bs, 2, -1])
    return _contrastive_ce(_pairwise_sim(sim, pooled[:, 0], pooled[:, 1]))


def _ph_encoder_cl_loss(ph_encoder, sim, sample):
    """Contrastive loss between two forward passes of the phoneme encoder (``cl_version`` ``"v2"``).

    Each pass is mean-pooled over non-padding phonemes (``txt_tokens > 0``).
    The two passes use identical inputs, so they presumably differ only through
    encoder stochasticity such as dropout (TODO confirm).
    """
    txt_tokens = sample['txt_tokens']
    src_nonpadding = (txt_tokens > 0).float()[:, :, None]
    n_valid = src_nonpadding.sum(1)

    def encode():
        out = ph_encoder(txt_tokens, bert_feats=sample['bert_feats'], ph2word=sample['ph2word'])
        return (out * src_nonpadding).sum(1) / n_valid

    z1 = encode()
    z2 = encode()
    return _contrastive_ce(_pairwise_sim(sim, z1, z2))


def _word_cl_loss(bert, sim, cl_feats):
    """Token-level contrastive loss between aligned positions of the two BERT views (``cl_version`` ``"v3"``).

    For each sample, the hidden state at every attended position of view 0 is
    contrasted against the hidden states at the same positions of view 1.
    Per-sample losses are averaged, weighted by the number of attended positions.
    Assumes both views of a sample attend to the same token positions
    (TODO confirm with the data pipeline).
    """
    bs, t, input_ids, attention_mask, token_type_ids = _flatten_paired_views(cl_feats)
    cl_output = bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    # Undo the flattening so the view axis is explicit: [bs, 2, t, hidden].
    # (A flat masked select followed by view([-1, 2, H]) would pair adjacent
    # tokens of the same view instead of view 0 with view 1.)
    hidden = cl_output.last_hidden_state.reshape([bs, 2, t, -1])
    valid = attention_mask.reshape([bs, 2, t])[:, 0].bool()  # [bs, t], view-0 mask
    lengths = valid.sum(-1)  # [bs]
    loss_accu = hidden.new_zeros(())
    for i in range(bs):
        if lengths[i] == 0:
            # Cross-entropy over zero rows would be NaN; such a sample contributes nothing.
            continue
        z1 = hidden[i, 0][valid[i]]
        z2 = hidden[i, 1][valid[i]]
        loss_accu = loss_accu + _contrastive_ce(_pairwise_sim(sim, z1, z2)) * lengths[i]
    return loss_accu / lengths.sum().clamp(min=1)


def _nli_cl_loss(bert, pooler, sim, cl_feats):
    """Contrastive loss with hard negatives (``cl_version`` ``"v5"``).

    ``cl_feats`` holds three BERT input groups: ``'sent0'`` (anchor),
    ``'sent1'`` (positive) and ``'hard_neg'``. Each anchor is scored against
    every positive and every hard negative in the batch.
    """
    def embed(key):
        feats = cl_feats[key]
        attention_mask = feats['cl_attention_mask']
        cl_output = bert(feats['cl_input_ids'], attention_mask=attention_mask,
                         token_type_ids=feats['cl_token_type_ids'])
        return pooler(attention_mask, cl_output)

    z1 = embed('sent0')
    z2 = embed('sent1')
    z3 = embed('hard_neg')
    cos_sim = torch.cat([_pairwise_sim(sim, z1, z2), _pairwise_sim(sim, z1, z3)], 1)  # [n_sent, n_sent * 2]
    return _contrastive_ce(cos_sim)


def _mlm_loss(bert_for_mlm, cl_feats):
    """Masked-language-model cross-entropy averaged over positions whose label is ``>= 0``."""
    mlm_labels = cl_feats['mlm_labels'].reshape(-1)
    prediction_scores = bert_for_mlm(cl_feats['mlm_input_ids'], cl_feats['mlm_attention_mask']).logits
    # Use the logits' own width rather than tokenizer.vocab_size so that a
    # resized embedding table cannot misalign the reshape.
    token_losses = nn.CrossEntropyLoss(reduction="none")(
        prediction_scores.reshape(-1, prediction_scores.size(-1)), mlm_labels)
    masked = mlm_labels >= 0
    # sum / count is the mean over masked positions, but it stays finite when nothing is masked.
    return token_losses[masked].sum() / masked.sum().clamp(min=1)


class PortaSpeechAdvMLMTask(PortaSpeechAdvTask):
    """PortaSpeech adversarial task with extra contrastive / MLM losses on the text encoder."""

    def build_scheduler(self, optimizer):
        """Build the ``[generator, discriminator]`` learning-rate schedulers.

        The generator scheduler comes from ``FastSpeechTask.build_scheduler``
        applied to ``optimizer[0]``. The discriminator uses ``StepLR`` on
        ``optimizer[1]``, configured by ``hparams["discriminator_scheduler_params"]``.
        """
        return [
            FastSpeechTask.build_scheduler(self, optimizer[0]),  # Generator Scheduler
            torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1],  # Discriminator Scheduler
                                            **hparams["discriminator_scheduler_params"]),
        ]

    def on_before_optimization(self, opt_idx):
        """Clip gradient norms to ``hparams['clip_grad_norm']`` for the parameters of optimizer ``opt_idx``.

        ``opt_idx`` 0 or 2 clips ``dp_params``, then either ``bert_params`` and
        ``gen_params_except_bert_and_dp`` (when ``use_bert``) or
        ``gen_params_except_dp``. Any other index clips ``disc_params``.
        """
        if opt_idx in [0, 2]:
            nn.utils.clip_grad_norm_(self.dp_params, hparams['clip_grad_norm'])
            if self.use_bert:
                nn.utils.clip_grad_norm_(self.bert_params, hparams['clip_grad_norm'])
                nn.utils.clip_grad_norm_(self.gen_params_except_bert_and_dp, hparams['clip_grad_norm'])
            else:
                nn.utils.clip_grad_norm_(self.gen_params_except_dp, hparams['clip_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["clip_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Step both schedulers to ``global_step // accumulate_grad_batches``.

        Both schedulers are stepped regardless of ``optimizer_idx``.
        """
        if self.scheduler is not None:
            self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches'])
            self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches'])

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the losses for one optimizer pass over a training batch.

        Args:
            sample: batch dict; ``'mels'`` and ``'txt_tokens'`` are read here and
                the whole dict is passed to ``run_model`` / ``run_contrastive_learning``.
            batch_idx: index of the batch (not used here).
            optimizer_idx: 0 for the generator pass, 1 for the discriminator pass.

        Returns:
            ``(total_loss, loss_output)``, or ``None`` when the discriminator
            pass is skipped for this step (presumably treated as "no update"
            by the trainer; TODO confirm).
        """
        loss_output = {}
        loss_weights = {}
        disc_start = self.global_step >= hparams["disc_start_steps"] and hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            loss_output, model_out = self.run_model(sample, infer=False)
            # Detached outputs, reused as the "fake" input of the discriminator pass.
            self.model_out_gt = self.model_out = \
                {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            # The adversarial term is added only once the discriminator is active.
            # The run_model losses and the contrastive/MLM losses are trained on every step.
            if disc_start:
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                # The generator wants the discriminator to output 1 ("real") on its mels.
                if p_ is not None:
                    loss_output['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    loss_output['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']

            loss_output2, model_out2 = self.run_contrastive_learning(sample)
            loss_output.update(loss_output2)
            model_out.update(model_out2)

        elif optimizer_idx == 1:
            #######################
            #    Discriminator    #
            #######################
            # Skip instead of falling through to sum([]) == 0, which is not a tensor loss.
            if not (disc_start and self.global_step % hparams['disc_interval'] == 0):
                return None
            model_out = self.model_out_gt
            mel_g = sample['mels']
            mel_p = model_out['mel_out']
            o = self.mel_disc(mel_g)
            p, pc = o['y'], o['y_c']
            o_ = self.mel_disc(mel_p)
            p_, pc_ = o_['y'], o_['y_c']
            # Target 1 for ground-truth mels and 0 for generated mels.
            if p_ is not None:
                loss_output["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                loss_output["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
            if pc_ is not None:
                loss_output["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                loss_output["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))

        # Only differentiable tensors contribute; detached "*_v" logging entries are skipped.
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in loss_output.items()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_contrastive_learning(self, sample):
        """Compute the optional contrastive ('cl') and masked-LM ('mlm') losses.

        The contrastive variant is chosen by ``hparams['cl_version']`` (default
        ``"v1"``) and enabled when ``hparams['lambda_cl'] > 0``. The MLM loss is
        enabled when ``hparams['lambda_mlm'] > 0``.

        Returns:
            ``(losses, outputs)``. ``losses`` holds the weighted ``'cl'`` / ``'mlm'``
            terms and their detached unweighted values ``'cl_v'`` / ``'mlm_v'``.
            ``outputs`` is currently always empty.

        Raises:
            NotImplementedError: for an unknown ``cl_version``.
        """
        losses = {}
        outputs = {}
        encoder = self.model.encoder
        if hparams['lambda_cl'] > 0:
            cl_version = hparams.get("cl_version", "v1")
            if cl_version in ("v1", "v4"):  # v4: same objective, Wiki data
                cl_loss = _sentence_cl_loss(encoder.bert.bert, encoder.pooler, encoder.sim, sample['cl_feats'])
            elif cl_version == "v2":  # phoneme-encoder output as the sentence embedding
                cl_loss = _ph_encoder_cl_loss(encoder, encoder.sim, sample)
            elif cl_version == "v3":  # word-level contrastive learning
                cl_loss = _word_cl_loss(encoder.bert.bert, encoder.sim, sample['cl_feats'])
            elif cl_version == "v5":  # NLI data with hard negatives
                cl_loss = _nli_cl_loss(encoder.bert.bert, encoder.pooler, encoder.sim, sample['cl_feats'])
            else:
                raise NotImplementedError(f"Unsupported cl_version: {cl_version}")
            losses['cl_v'] = cl_loss.detach()
            losses['cl'] = cl_loss * hparams['lambda_cl']
        if hparams['lambda_mlm'] > 0:
            mlm_loss = _mlm_loss(encoder.bert, sample['cl_feats'])
            losses['mlm'] = mlm_loss * hparams['lambda_mlm']
            losses['mlm_v'] = mlm_loss.detach()
        return losses, outputs