# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion

logger = logging.getLogger(__name__)

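# StackUnitSequenceGenerator performs greedy (viterbi, beam size 1) decoding of stacked
# discrete units: at each step the decoder predicts n_frames_per_step units, which are
# packed back into a single vocabulary index before being fed in as the next decoder
# input. Hypotheses are returned in the same list-of-dicts format used by fairseq's
# SequenceGenerator.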
class StackUnitSequenceGenerator(nn.Module):
    def __init__(self, tgt_dict, vocab_size):
        super().__init__()
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos()
        self.unk = tgt_dict.unk()
        self.offset = len(tgt_dict) - vocab_size
        self.vocab_size = vocab_size

    def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor:
        if n_frames_per_step <= 1:
            return input

        bsz, _, n = input.shape
        assert n == n_frames_per_step

        scale = [
            pow(self.vocab_size, n_frames_per_step - 1 - i)
            for i in range(n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0).to(input.device)
        mask = input >= self.offset
        res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset
        return res
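    # pack_units maps a group of n_frames_per_step unit indices to a single index via a
    # base-`vocab_size` positional encoding. A minimal illustration (hypothetical numbers,
    # assuming len(tgt_dict) == vocab_size + 4 so that offset == 4):
    #
    #   gen = StackUnitSequenceGenerator(tgt_dict, vocab_size=100)  # offset = 104 - 100 = 4
    #   x = torch.tensor([[[5, 9]]])                                # raw units 1 and 5
    #   gen.pack_units(x, n_frames_per_step=2)                      # (1 * 100 + 5 * 1) + 4 = 109
    #
    # Entries below `offset` (pad/eos/etc.) are zeroed out of the sum by `mask`.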
    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        # currently only support viterbi search for stacked units
        model = models[0]
        model.eval()

        max_len = model.max_decoder_positions()
        # TODO: incorporate max_len_a and max_len_b

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len, _ = src_tokens.size()
        n_frames_per_step = model.decoder.n_frames_per_step

        # initialize
        encoder_out = model.forward_encoder(
            src_tokens, src_lengths, speaker=sample["speaker"]
        )
        incremental_state = {}
        pred_out, attn, scores = [], [], []
        finished = src_tokens.new_zeros((bsz,)).bool()

        prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos)
        for _ in range(max_len):
            cur_out, cur_extra = model.forward_decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                incremental_state=incremental_state,
            )

            lprobs = model.get_normalized_probs([cur_out], log_probs=True)
            # never select pad, unk
            lprobs[:, :, self.pad] = -math.inf
            lprobs[:, :, self.unk] = -math.inf

            cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2)
            scores.append(cur_pred_lprob)
            pred_out.append(cur_pred_out)

            prev_output_tokens = torch.cat(
                (
                    prev_output_tokens,
                    self.pack_units(
                        cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step
                    ),
                ),
                dim=1,
            )

            attn.append(cur_extra["attn"][0])

            cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1)
            finished = finished | cur_finished
            if finished.sum().item() == bsz:
                break

        pred_out = torch.cat(pred_out, dim=1).view(bsz, -1)
        attn = torch.cat(attn, dim=2)
        alignment = attn.max(dim=1)[1]
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        scores = torch.cat(scores, dim=1)
        eos_idx = (pred_out == self.eos).nonzero(as_tuple=True)
        out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len)
        for b, l in zip(eos_idx[0], eos_idx[1]):
            out_lens[b] = min(l, out_lens[b])

        hypos = [
            [
                {
                    "tokens": pred_out[b, :out_len],
                    "attn": attn[b, :, :out_len],
                    "alignment": alignment[b, :out_len],
                    "positional_scores": scores[b, :out_len],
                    "score": utils.item(scores[b, :out_len].sum().data),
                }
            ]
            for b, out_len in zip(range(bsz), out_lens)
        ]

        return hypos
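    # Each hypothesis mirrors fairseq's SequenceGenerator output: a list (one entry per
    # sentence) of n-best lists (here always length 1) of dicts. A sketch of how a caller
    # might unpack it:
    #
    #   hypos = gen.generate([model], sample)
    #   units = hypos[0][0]["tokens"]   # LongTensor of predicted unit indices
    #   score = hypos[0][0]["score"]    # sum of the per-step log-probabilities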
@register_task("speech_to_speech")
class SpeechToSpeechTask(LegacyFairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--multitask-config-yaml",
            type=str,
            default=None,
            help="Configuration YAML filename for the multitasks (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument(
            "--target-is-code",
            action="store_true",
            help="set if target is discrete unit instead of spectrogram",
        )
        parser.add_argument(
            "--target-code-size", type=int, default=None, help="# discrete units"
        )
        parser.add_argument(
            "--n-frames-per-step",
            type=int,
            default=1,
            help="# stacked frames, use 0 for reduced discrete unit sequence",
        )
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument(
            "--eval-args",
            type=str,
            default="{}",
            help='generation args for speech-to-unit model, e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string',
        )
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument(
            "--mcd-normalize-type",
            type=str,
            default="targ",
            choices=["targ", "pred", "path"],
        )
        parser.add_argument(
            "--vocoder",
            type=str,
            default="griffin_lim",
            choices=["griffin_lim", "hifigan", "code_hifigan"],
        )
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
        parser.add_argument(
            "--infer-target-lang",
            type=str,
            default="",
            help="target language for inference",
        )
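    # These flags are consumed through the standard fairseq CLI. An illustrative
    # speech-to-unit invocation (hypothetical paths; arch/criterion names may differ
    # for a given setup):
    #
    #   fairseq-train /path/to/manifest \
    #       --task speech_to_speech --config-yaml config.yaml \
    #       --target-is-code --target-code-size 100 --vocoder code_hifigan \
    #       --criterion speech_to_unit --arch s2ut_transformer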
    def __init__(self, args, tgt_dict, infer_tgt_lang_id=None):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)

                        if not self.eos_token_mt:
                            raise Warning(
                                "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                            )

        self._infer_tgt_lang_id = infer_tgt_lang_id
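    # The first-pass decoder task (selected by first_pass_decoder_task_index) supplies the
    # auxiliary text dictionary tgt_dict_mt; when that task prepends <bos> and appends a
    # target language tag, an explicit eos_token must be configured so the text decoder can
    # terminate. A small introspection sketch over the parsed multitask config (hypothetical
    # YAML path):
    #
    #   mt_cfg = MultitaskConfig(Path("/path/to/manifest") / "multitask.yaml")
    #   for name, cfg in mt_cfg.get_all_tasks().items():
    #       print(name, len(cfg.tgt_dict), cfg.prepend_bos_and_append_tgt_lang_tag)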
    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
        tgt_dict = None
        infer_tgt_lang_id = None
        if args.target_is_code:
            if data_cfg.prepend_tgt_lang_tag_as_bos:
                # dictionary with language tags
                dict_path = Path(args.data) / data_cfg.vocab_filename
                if not dict_path.is_file():
                    raise FileNotFoundError(
                        f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}"
                    )
                tgt_dict = Dictionary.load(dict_path.as_posix())

                # target language for inference
                if args.infer_target_lang != "":
                    tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
                        args.infer_target_lang
                    )
                    infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
                    assert infer_tgt_lang_id != tgt_dict.unk()
            else:
                assert args.target_code_size is not None

                tgt_dict = Dictionary()
                for i in range(args.target_code_size):
                    tgt_dict.add_symbol(str(i))
            logger.info(f"dictionary size: {len(tgt_dict):,}")

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')

        assert args.n_frames_per_step >= 1
        assert (
            not args.eval_inference
            or (args.target_is_code and args.vocoder == "code_hifigan")
            or (not args.target_is_code and args.vocoder != "code_hifigan")
        )

        return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)
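    # For a plain unit target (no language tags), the dictionary built above is just the
    # symbols "0" .. str(target_code_size - 1) appended after fairseq's default specials,
    # so unit k lives at index k + 4. This is why StackUnitSequenceGenerator computes
    # offset = len(tgt_dict) - vocab_size and why codes passed to the vocoder below are
    # shifted by 4:
    #
    #   d = Dictionary()
    #   for i in range(100):
    #       d.add_symbol(str(i))
    #   d.index("0")  # == 4, right after <s>, <pad>, </s>, <unk>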
    def build_criterion(self, args):
        from fairseq import criterions

        if len(self.multitask_tasks) > 0:
            if self.args.target_is_code and not args._name.startswith(
                "speech_to_unit"
            ):
                raise ValueError(
                    "set --criterion speech_to_unit for speech-to-unit loss with multitask"
                )
            elif not self.args.target_is_code and not args._name.startswith(
                "speech_to_spectrogram"
            ):
                raise ValueError(
                    "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
                )

        return criterions.build_criterion(args, self)
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv(
            root=self.args.data,
            data_cfg=self.data_cfg,
            splits=split,
            is_train_split=split.startswith("train"),
            epoch=epoch,
            seed=self.args.seed,
            target_is_code=self.args.target_is_code,
            tgt_dict=self.target_dictionary,
            n_frames_per_step=self.args.n_frames_per_step,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions
    def build_model(self, args, from_checkpoint=False):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_transformed_channels
        args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None
        args.n_frames_per_step = self.args.n_frames_per_step

        model = super().build_model(args, from_checkpoint)

        if len(self.multitask_tasks) > 0:
            from fairseq.models.speech_to_speech.s2s_transformer import (
                S2STransformerMultitaskModelBase,
            )

            assert isinstance(model, S2STransformerMultitaskModelBase)

        if self.args.eval_inference:
            self.eval_gen_args = json.loads(self.args.eval_args)
            self.generator = self.build_generator(
                [model], Namespace(**self.eval_gen_args)
            )

        return model
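    # --eval-args is parsed as JSON and forwarded both to build_generator and, during
    # validation, to the vocoder call. An illustrative value (keys beyond "beam",
    # "max_len_a", and "dur_prediction" depend on the generator/vocoder in use):
    #
    #   --eval-args '{"beam": 5, "max_len_a": 1, "dur_prediction": true}'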
    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs=None,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        # guard against None so the ** unpacking below does not fail
        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )
    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not self.args.target_is_code or self.args.eval_inference:
            from fairseq.models.text_to_speech.vocoder import get_vocoder

            self.vocoder = get_vocoder(self.args, self.data_cfg)
            self.vocoder = (
                self.vocoder.cuda()
                if torch.cuda.is_available() and not self.args.cpu
                else self.vocoder.cpu()
            )

        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if self.args.target_is_code:
            if self.args.n_frames_per_step == 1:
                if has_dual_decoder:
                    seq_generator = self.build_generator_dual_decoder(
                        models,
                        args,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
                else:
                    seq_generator = super().build_generator(
                        models,
                        args,
                        seq_gen_cls=None,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
            else:
                assert (
                    getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
                ), "only support viterbi search for stacked units"
                seq_generator = StackUnitSequenceGenerator(
                    self.tgt_dict,
                    self.args.target_code_size,
                )
        else:
            if has_dual_decoder:
                if getattr(args, "teacher_forcing", False):
                    raise NotImplementedError
                else:
                    from fairseq.speech_generator import MultiDecoderSpeechGenerator

                    generator = MultiDecoderSpeechGenerator

                lang_token_ids_aux = {
                    i
                    for s, i in self.tgt_dict_mt.indices.items()
                    if TextTargetMultitaskData.is_lang_tag(s)
                }

                if extra_gen_cls_kwargs is None:
                    extra_gen_cls_kwargs = {}
                extra_gen_cls_kwargs[
                    "symbols_to_strip_from_output"
                ] = lang_token_ids_aux

                eos_id_mt = (
                    self.tgt_dict_mt.index(self.eos_token_mt)
                    if self.eos_token_mt
                    else None
                )
                assert eos_id_mt != self.tgt_dict_mt.unk()
                extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

                seq_generator = generator(
                    models,
                    args,
                    self.vocoder,
                    self.data_cfg,
                    self.target_dictionary_mt,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                    **extra_gen_cls_kwargs,
                )
            else:
                if getattr(args, "teacher_forcing", False):
                    from fairseq.speech_generator import (
                        TeacherForcingAutoRegressiveSpeechGenerator,
                    )

                    generator = TeacherForcingAutoRegressiveSpeechGenerator
                    logger.info("Teacher forcing mode for generation")
                else:
                    from fairseq.speech_generator import AutoRegressiveSpeechGenerator

                    generator = AutoRegressiveSpeechGenerator

                seq_generator = generator(
                    models[0],
                    self.vocoder,
                    self.data_cfg,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                )

        return seq_generator
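    # Generator selection above, summarized:
    #   target_is_code, n_frames_per_step == 1, single decoder -> fairseq SequenceGenerator
    #   target_is_code, n_frames_per_step == 1, dual decoder   -> MultiDecoderSequenceGenerator
    #   target_is_code, n_frames_per_step > 1                  -> StackUnitSequenceGenerator (greedy only)
    #   spectrogram target, single decoder                     -> (TeacherForcing)AutoRegressiveSpeechGenerator
    #   spectrogram target, dual decoder                       -> MultiDecoderSpeechGenerator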
    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output
    def valid_step(self, sample, model, criterion):
        for task_name in self.multitask_tasks.keys():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion
        )

        if self.args.eval_inference:
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

        return loss, sample_size, logging_output
    def valid_step_with_inference(self, sample, model, generator):
        if self.args.target_is_code:
            hypos = generator.generate([model], sample)
            tgt_lens = (
                sample["target_lengths"] - 1
            ) * self.args.n_frames_per_step  # strip <eos>
            for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)):
                hypos[b][0]["targ_waveform"] = self.vocoder(
                    {"code": f[:l] - 4},  # remove <bos>, <pad>, <eos>, <unk>
                    dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                )
                if len(hypos[b][0]["tokens"]) > 0:
                    hypos[b][0]["waveform"] = self.vocoder(
                        {"code": hypos[b][0]["tokens"] - 4},
                        dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                    )
                else:
                    hypos[b][0]["waveform"] = torch.flip(
                        hypos[b][0]["targ_waveform"], dims=[0]
                    )
        else:
            hypos = [
                [hypo] for hypo in generator.generate(model, sample, has_targ=True)
            ]

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "path_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo[0]["targ_waveform"] for hypo in hypos],
            [hypo[0]["waveform"] for hypo in hypos],
            self.data_cfg.output_sample_rate,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["path_frames"] += pathmap.sum().item()
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
        losses["norm_frames"] = losses[
            f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames"
        ]

        return hypos, losses
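    # pathmap (the last element of `extra`) is treated as a binary alignment matrix with one
    # row per target frame and one column per predicted frame. A hypothetical 3x4 alignment
    # illustrates the bookkeeping above:
    #
    #   pathmap = torch.tensor([[1, 1, 0, 0],
    #                           [0, 1, 0, 0],
    #                           [0, 0, 1, 1]])
    #   # targ_frames = 3, pred_frames = 4, path_frames = 5
    #   # nins = (pathmap.sum(dim=1) - 1).sum() = 2   extra predicted frames on the path
    #   # ndel = (pathmap.sum(dim=0) - 1).sum() = 1   extra target frames on the path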
    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if self._infer_tgt_lang_id is not None:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                    bos_token=self._infer_tgt_lang_id,
                )
            else:
                return super().inference_step(
                    generator,
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
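# Programmatic usage sketch (hypothetical args/sample; in practice fairseq's CLI entry
# points drive this task), assuming `args` comes from fairseq's option parser and `sample`
# from a loaded dataset's collater:
#
#   task = SpeechToSpeechTask.setup_task(args)
#   task.load_dataset("test")
#   models = [task.build_model(args)]
#   generator = task.build_generator(models, args)
#   hypos = task.inference_step(generator, models, sample)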