# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import sys

import torch

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from fairseq.data import Dictionary
from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING, DictConfig

logger = logging.getLogger(__name__)
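
# The task below does language modeling over discrete speech units ("unit LM"),
# with two optional auxiliary channels per token: duration and F0 (pitch).
# UnitDictionary provides the identity token <-> id mapping used to encode
# these integer unit streams.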


class UnitDictionary(Dictionary):
    """
    A fixed-size Dictionary that operates on integer-valued tokens
    with a trivial (identity) token <-> id mapping.
    Special symbols (bos, eos, ...) have ids above n_units.
    """

    def __init__(
        self,
        *,  # begin keyword-only arguments
        n_units,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
        clip=False,
    ):
        self.n_units = n_units
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.clip = clip

        self.symbols = []
        self.count = []
        self.indices = {}
        for i in range(n_units):
            self.add_symbol(str(i))

        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor:
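        """
        Encode a whitespace-separated line of integer unit tokens into a
        torch.IntTensor, optionally clipping each unit to n_units - 1 and
        adding the bos/eos ids.
        """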
        words = [int(x) for x in line.split()]
        if self.clip:
            words = [min(self.n_units - 1, word) for word in words]
        if prepend_bos:
            words = [self.bos_index] + words
        if append_eos:
            words.append(self.eos_index)
        ids = torch.IntTensor(words)
        return ids


@dataclass
class SpeechUnitModelingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "Path to data config.json"})
    max_token_duration: int = field(
        default=20, metadata={"help": "all token durations are capped to this value"}
    )
    tokens_per_sample: int = field(
        default=1024, metadata={"help": "tokens in a sample"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max target positions"}
    )

    # duration modeling
    ignore_duration_input: bool = field(
        default=False, metadata={"help": "whether token durations should be zeroed out"}
    )
    discrete_duration: bool = field(
        default=False, metadata={"help": "treat duration as a discrete variable"}
    )

    # F0 modeling
    ignore_f0_input: bool = field(
        default=False, metadata={"help": "whether F0 should be zeroed out"}
    )
    discrete_f0: bool = field(
        default=False,
        metadata={"help": "load quantized F0; bin count is taken from the data config"},
    )
    log_f0: bool = field(
        default=False, metadata={"help": "whether F0 should be modeled in log space"}
    )
    normalize_f0_mean: bool = field(
        default=False, metadata={"help": "whether to normalize F0 by speaker mean"}
    )
    normalize_f0_std: bool = field(
        default=False, metadata={"help": "whether to normalize F0 by speaker stddev"}
    )
    interpolate_f0: bool = field(
        default=False,
        metadata={"help": "whether to interpolate F0 for non-voiced segments"},
    )

    # input/output streams
    stream_shifts: str = field(
        default="0,0",
        metadata={
            "help": (
                "comma-separated integer list denoting right-shift for "
                "duration and pitch streams"
            )
        },
    )


@register_task("speech_unit_modeling", dataclass=SpeechUnitModelingConfig)
class SpeechUnitLanguageModelingTask(FairseqTask):
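    """
    Language modeling over discrete speech units, with optional per-token
    duration and F0 channels (each either continuous or discretized).
    """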

    def __init__(self, cfg: SpeechUnitModelingConfig) -> None:
        super().__init__(cfg)
        assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean

        self.data_config = ExpressiveCodeDataConfig(cfg.data)
        self._source_dictionary = self._target_dictionary = UnitDictionary(
            n_units=self.data_config.n_units
        )
        self._source_duration_dictionary = self._target_duration_dictionary = (
            UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True)
            if self.cfg.discrete_duration
            else None
        )
        self._source_f0_dictionary = self._target_f0_dictionary = (
            UnitDictionary(n_units=self.data_config.f0_vq_n_units)
            if self.cfg.discrete_f0
            else None
        )

        self._channel_names = ["token", "duration", "f0"]
        self._channel_sizes = [
            len(self.target_dictionary),
            len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1,
            len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1,
        ]
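
    # Dictionaries and channel metadata are exposed as read-only properties so
    # that models and criteria can query them through the task interface.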

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return self._source_dictionary

    @property
    def source_duration_dictionary(self) -> Optional[Dictionary]:
        return self._source_duration_dictionary

    @property
    def source_f0_dictionary(self) -> Optional[Dictionary]:
        return self._source_f0_dictionary

    @property
    def channel_names(self) -> List[str]:
        return self._channel_names

    @property
    def channel_sizes(self) -> List[int]:
        return self._channel_sizes

    @property
    def dictionary(self) -> Optional[Dictionary]:
        return self._source_dictionary

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self._target_dictionary

    @property
    def target_duration_dictionary(self) -> Optional[Dictionary]:
        return self._target_duration_dictionary

    @property
    def target_f0_dictionary(self) -> Optional[Dictionary]:
        return self._target_f0_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return [self._dictionaries[l] for l in self.cfg.labels]

    @classmethod
    def setup_task(
        cls, cfg: SpeechUnitModelingConfig, **kwargs
    ) -> "SpeechUnitLanguageModelingTask":
        return cls(cfg)

    def load_dataset(self, split: str, **kwargs) -> None:
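        """Load one split (e.g. "train" or "valid") as a CodeDataset built
        from the manifest listed in the data config."""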
        self.datasets[split] = CodeDataset(
            manifest=self.data_config.manifests[split],
            dictionary=self.source_dictionary,
            dur_dictionary=self.source_duration_dictionary,
            f0_dictionary=self.source_f0_dictionary,
            config=self.data_config,
            discrete_dur=self.cfg.discrete_duration,
            discrete_f0=self.cfg.discrete_f0,
            log_f0=self.cfg.log_f0,
            normalize_f0_mean=self.cfg.normalize_f0_mean,
            normalize_f0_std=self.cfg.normalize_f0_std,
            interpolate_f0=self.cfg.interpolate_f0,
            shifts=self.cfg.stream_shifts,
        )

    def max_positions(self) -> Tuple[int, int]:
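        # The task itself imposes no sequence-length limit on source or target.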
        return (sys.maxsize, sys.maxsize)

    def build_criterion(self, cfg: DictConfig):
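        # Local import keeps fairseq.criterions from being loaded at module
        # import time.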
        import fairseq.criterions

        return fairseq.criterions.build_criterion(cfg, self)
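

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the fairseq API): shows
# the identity unit <-> id mapping of UnitDictionary. The n_units value and
# the example line below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d = UnitDictionary(n_units=100)
    # Units 0..99 map to ids 0..99; <s>, <pad>, </s>, <unk> get ids 100..103.
    ids = d.encode_line("12 7 99", append_eos=True, prepend_bos=True)
    print(ids)  # tensor([100,  12,   7,  99, 102], dtype=torch.int32)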