# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import sys
from typing import Dict, List, Optional, Tuple

import numpy as np
from dataclasses import dataclass, field

from fairseq.data import Dictionary, HubertDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING

logger = logging.getLogger(__name__)
class LabelEncoder(object):
    """Encodes a label line into token ids with a fairseq Dictionary (no EOS, no new symbols)."""

    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )
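

# A minimal usage sketch (paths and label names are hypothetical): assuming
# `dict.km.txt` is a fairseq Dictionary over k-means cluster units, the encoder
# maps a space-separated label string to integer ids.
#
#     km_dict = Dictionary.load("labels/dict.km.txt")   # hypothetical path
#     encode = LabelEncoder(km_dict)
#     ids = encode("5 5 12 12 12 7")  # tensor of label ids, no EOS appended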
@dataclass
class HubertPretrainingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    fine_tuning: bool = field(
        default=False, metadata={"help": "set to true if fine-tuning Hubert"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "if set, looks for labels in this directory instead",
        },
    )
    label_rate: float = field(
        default=-1.0,
        metadata={"help": "label frame rate. -1.0 for sequence label"},
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude sample longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to crop to for batching"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to crop to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )
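

# Sketch of a typical pre-training configuration (values are illustrative, not
# a released recipe): frame-level k-means labels at a 50 Hz label rate with
# random cropping and no audio padding.
#
#     cfg = HubertPretrainingConfig(
#         data="/path/to/manifests",        # hypothetical path
#         labels=["km"],
#         label_dir="/path/to/km_labels",   # hypothetical path
#         label_rate=50.0,
#         max_sample_size=250_000,
#         min_sample_size=32_000,
#         random_crop=True,
#         pad_audio=False,
#     )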
@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig)
class HubertPretrainingTask(FairseqTask):

    cfg: HubertPretrainingConfig

    def __init__(
        self,
        cfg: HubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"HubertPretrainingTask Config {cfg}")

        self.cfg = cfg
        self.fine_tuning = cfg.fine_tuning

        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return None

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self.state.target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return self.state.dictionaries

    @classmethod
    def setup_task(
        cls, cfg: HubertPretrainingConfig, **kwargs
    ) -> "HubertPretrainingTask":
        return cls(cfg)

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def get_label_dir(self) -> str:
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        pad_list = [dict.pad() for dict in dicts]
        eos_list = [dict.eos() for dict in dicts]
        procs = [LabelEncoder(dict) for dict in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.datasets[split] = HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )
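
    # Expected on-disk layout for a split (a sketch; exact names depend on how
    # the manifests and labels were prepared, e.g. with the wav2vec/HuBERT
    # example scripts). Paths below are hypothetical:
    #
    #     {cfg.data}/train.tsv        audio manifest: root dir line, then "<rel_path>\t<num_samples>" rows
    #     {label_dir}/train.km        one line of labels per manifest row
    #     {label_dir}/dict.km.txt     fairseq Dictionary over the label units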
    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
        return indices