from itertools import count, islice
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypeVar, Union, overload
from collections import defaultdict
import json

import spaces  # Hugging Face Spaces integration
import torch
from datasets import Dataset, Audio

from dataspeech import rate_apply, pitch_apply, snr_apply, squim_apply
from metadata_to_text import bins_to_text, speaker_level_relative_to_gender

Row = Dict[str, Any]
T = TypeVar("T")
BATCH_SIZE = 20

# Yield successive batches of up to n items; optionally pair each batch with
# the global indices of its items.
@overload
def batched(it: Iterable[T], n: int) -> Iterable[List[T]]:
    ...
@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[List[T]]:
    ...
@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[Tuple[List[int], List[T]]]:
    ...
def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[List[T]], Iterable[Tuple[List[int], List[T]]]]:
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch
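# Quick sanity check (illustrative):
#   list(batched("abcde", 2))                    -> [['a', 'b'], ['c', 'd'], ['e']]
#   next(batched("abcde", 2, with_indices=True)) -> ([0, 1], ['a', 'b'])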

def analyze(
    batch: List[Row],
    audio_column_name: str, text_column_name: str,
    cache: Optional[Dict[str, List[Any]]] = None,
) -> List[Row]:
    cache = {} if cache is None else cache  # NOTE: accepted for future reuse; not consulted below
    # TODO: add speaker and gender to app
    speaker_id_column_name = "speaker_id"
    gender_column_name = "gender"
    tmp_dict = defaultdict(list)
    for sample in batch:
        for key in sample:
            if key in [audio_column_name, text_column_name, speaker_id_column_name, gender_column_name]:
                # Audio arrives as a list of {"src": url} dicts; keep only the first source URL.
                if key == audio_column_name:
                    tmp_dict[key].append(sample[key][0]["src"])
                else:
                    tmp_dict[key].append(sample[key])
    tmp_dataset = Dataset.from_dict(tmp_dict).cast_column(audio_column_name, Audio())
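    # Illustrative tmp_dataset row at this point (placeholder values):
    #   {audio_column_name: <decoded Audio>, text_column_name: "...", "speaker_id": "...", "gender": "..."}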
    ## 1. Extract continuous tags
    num_gpus = torch.cuda.device_count()
    squim_dataset = tmp_dataset.map(
        squim_apply,
        batched=True,
        batch_size=BATCH_SIZE,
        with_rank=num_gpus > 0,
        num_proc=num_gpus or None,  # num_proc=0 is invalid on CPU-only machines
        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
        fn_kwargs={"audio_column_name": audio_column_name},
    )
    pitch_dataset = tmp_dataset.map(
        pitch_apply,
        batched=True,
        batch_size=BATCH_SIZE,
        with_rank=num_gpus > 0,
        num_proc=num_gpus or None,
        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
        fn_kwargs={"audio_column_name": audio_column_name, "penn_batch_size": 4096},
    )
    snr_dataset = tmp_dataset.map(
        snr_apply,
        batched=True,
        batch_size=BATCH_SIZE,
        with_rank=num_gpus > 0,
        num_proc=num_gpus or None,
        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
        fn_kwargs={"audio_column_name": audio_column_name},
    )
    rate_dataset = tmp_dataset.map(
        rate_apply,
        with_rank=False,
        num_proc=1,
        remove_columns=[audio_column_name],  # trick to avoid rewriting audio
        fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
    )
    enriched_dataset = pitch_dataset.add_column("snr", snr_dataset["snr"]).add_column("c50", snr_dataset["c50"])
    enriched_dataset = enriched_dataset.add_column("speaking_rate", rate_dataset["speaking_rate"]).add_column("phonemes", rate_dataset["phonemes"])
    enriched_dataset = enriched_dataset.add_column("stoi", squim_dataset["stoi"]).add_column("si-sdr", squim_dataset["sdr"]).add_column("pesq", squim_dataset["pesq"])
    ## 2. Map continuous tags to text tags
    with open("./v01_text_bins.json") as json_file:
        text_bins_dict = json.load(json_file)
    with open("./v01_bin_edges.json") as json_file:
        bin_edges_dict = json.load(json_file)
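    # Expected shape of the two JSON files (keys from the usage below; values illustrative):
    #   v01_text_bins.json: {"speaker_level_pitch_bins": [...labels...], "speaker_rate_bins": [...], ...}
    #   v01_bin_edges.json: {"pitch_bins_male": [...floats...], "pitch_bins_female": [...], ...}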
    speaker_level_pitch_bins = text_bins_dict.get("speaker_level_pitch_bins")
    speaker_rate_bins = text_bins_dict.get("speaker_rate_bins")
    snr_bins = text_bins_dict.get("snr_bins")
    reverberation_bins = text_bins_dict.get("reverberation_bins")
    utterance_level_std = text_bins_dict.get("utterance_level_std")
    # The metadata_to_text helpers take and return a list of dataset splits, so wrap, then unwrap.
    enriched_dataset = [enriched_dataset]
    if "gender" in batch[0] and "speaker_id" in batch[0]:
        bin_edges = None
        if "pitch_bins_male" in bin_edges_dict and "pitch_bins_female" in bin_edges_dict:
            bin_edges = {"male": bin_edges_dict["pitch_bins_male"], "female": bin_edges_dict["pitch_bins_female"]}
        enriched_dataset, _ = speaker_level_relative_to_gender(enriched_dataset, speaker_level_pitch_bins, "speaker_id", "gender", "utterance_pitch_mean", "pitch", batch_size=20, num_workers=1, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges)
    enriched_dataset, _ = bins_to_text(enriched_dataset, speaker_rate_bins, "speaking_rate", "speaking_rate", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speaking_rate", None))
    enriched_dataset, _ = bins_to_text(enriched_dataset, snr_bins, "snr", "noise", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("noise", None), lower_range=None)
    enriched_dataset, _ = bins_to_text(enriched_dataset, reverberation_bins, "c50", "reverberation", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("reverberation", None))
    enriched_dataset, _ = bins_to_text(enriched_dataset, utterance_level_std, "utterance_pitch_std", "speech_monotony", batch_size=20, num_workers=1, leading_split_for_bins=None, std_tolerance=None, save_dir=None, only_save_plot=False, bin_edges=bin_edges_dict.get("speech_monotony", None))
    enriched_dataset = enriched_dataset[0]
    for i, sample in enumerate(batch):
        new_sample = {}
        new_sample[audio_column_name] = f"<audio src='{sample[audio_column_name][0]['src']}' controls></audio>"
        for col in ["speaking_rate", "reverberation", "noise", "speech_monotony", "c50", "snr", "stoi", "pesq", "si-sdr"]:  # also computed: phonemes, utterance_pitch_std, utterance_pitch_mean
            new_sample[col] = enriched_dataset[col][i]
        if "gender" in batch[0] and "speaker_id" in batch[0]:
            new_sample["pitch"] = enriched_dataset["pitch"][i]
            new_sample[gender_column_name] = sample[gender_column_name]
            new_sample[speaker_id_column_name] = sample[speaker_id_column_name]
        new_sample[text_column_name] = sample[text_column_name]
        batch[i] = new_sample
    return batch

def run_dataspeech(
    rows: Iterable[Row], audio_column_name: str, text_column_name: str
) -> Iterable[List[Row]]:
    cache: Dict[str, List[Any]] = {}
    for batch in batched(rows, BATCH_SIZE):
        yield analyze(
            batch=batch,
            audio_column_name=audio_column_name,
            text_column_name=text_column_name,
            cache=cache,
        )
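
if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original app). The URL,
    # text, and speaker metadata below are placeholders; each row mirrors the
    # shape analyze() expects, with audio given as a list of {"src": url} dicts.
    demo_rows = [
        {
            "audio": [{"src": "https://example.com/sample.wav"}],
            "text": "Hello world.",
            "speaker_id": "spk_0",
            "gender": "female",
        }
    ]
    for enriched_batch in run_dataspeech(demo_rows, audio_column_name="audio", text_column_name="text"):
        for row in enriched_batch:
            print(row)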