import json
import os.path
import random
import re

import kaldiio
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm

from text import text_to_sequence


class BaseLoader(torch.utils.data.Dataset):
    """Base dataset over Kaldi-style data: an utterance list, mel features and text.

    Subclasses must implement ``get_mel_text_pair(utt)``; ``__getitem__`` returns
    whatever that method builds for the ``index``-th utterance of ``self.utts``.
    """

    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str):
        """
        :param utts: file path. A list of utts for this loader, one per line. These are the only
            utts this loader has access to. This loader only deals with text and feats; the
            other files besides `utts` may contain more entries.
        :param hparams: object providing ``n_mel_channels`` and ``sampling_rate``.
        :param feats_scp: Kaldi scp file of mel features, read lazily with ``kaldiio.load_scp``.
        :param utt2text: text file with one ``<utt> <text>`` entry per line.
        """
        self.n_mel_channels = hparams.n_mel_channels
        self.sampling_rate = hparams.sampling_rate
        self.utts = self.get_utts(utts)
        self.utt2feat = self.get_utt2feat(feats_scp)
        self.utt2text = self.get_utt2text(utt2text)

    def get_utts(self, utts: str) -> list:
        """Read utterance IDs (one per line) and return them in a reproducible shuffled order.

        Blank lines are skipped. The shuffle uses a private ``random.Random(1234)``
        (same permutation as seeding the module-level RNG with 1234) so that building
        a dataset does not reset the global ``random`` state.
        """
        with open(utts, 'r', encoding='utf-8') as f:
            utt_list = [line.strip() for line in f if line.strip()]
        random.Random(1234).shuffle(utt_list)
        return utt_list

    def get_utt2feat(self, feats_scp: str):
        """Open ``feats_scp`` with kaldiio; matrices are read on access (lazy load mode)."""
        utt2feat = kaldiio.load_scp(feats_scp)  # lazy load mode
        print(f"Succeed reading feats from {feats_scp}")
        return utt2feat

    def get_utt2text(self, utt2text: str) -> dict:
        """Parse ``<utt> <text>`` lines into a ``{utt: text}`` dict.

        The ID is separated from the text at the first run of whitespace (space
        or tab), so tab-separated files are parsed correctly; whitespace inside the
        text is kept. Blank lines are skipped.

        :raises ValueError: if a non-blank line has an ID but no text.
        """
        res = {}
        with open(utt2text, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.strip().split(maxsplit=1)
                if not fields:
                    continue
                if len(fields) < 2:
                    raise ValueError(f"{utt2text}:{lineno}: missing text for utt {fields[0]!r}")
                res[fields[0]] = fields[1]
        return res

    def get_mel_from_kaldi(self, utt):
        """Return the mel features of ``utt`` as a FloatTensor with the mel axis first.

        The stored matrix is transposed when needed so that ``shape[0] == n_mel_channels``.
        NOTE: when the frame count equals ``n_mel_channels`` the orientation is ambiguous
        and the matrix is returned unchanged.
        """
        feat = torch.FloatTensor(self.utt2feat[utt]).squeeze()
        assert self.n_mel_channels in feat.shape, \
            f"utt {utt}: feature shape {tuple(feat.shape)} has no axis of size {self.n_mel_channels}"
        return feat if feat.shape[0] == self.n_mel_channels else feat.T

    def get_text(self, utt):
        """Convert the text of ``utt`` to an IntTensor of symbol IDs via ``text_to_sequence``."""
        text = self.utt2text[utt]
        text_norm = text_to_sequence(text)
        return torch.IntTensor(text_norm)

    def get_mel_text_pair(self, utt):
        """Build the item for ``utt``; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_mel_text_pair")

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.utts[index])

    def __len__(self):
        return len(self.utts)

    def sample_test_batch(self, size):
        """Return ``size`` distinct items chosen at random with ``np.random``."""
        idx = np.random.choice(range(len(self)), size=size, replace=False)
        return [self[index] for index in idx]


class SpkIDLoader(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        This loader loads speaker as a speaker ID for embedding table

        NOTE(review): this constructor targets an older BaseLoader signature.
        BaseLoader.__init__ accepts only (utts, hparams, feats_scp, utt2text), so the
        super() call below raises TypeError, and ``get_dur_from_kaldi`` (used in
        get_mel_text_pair) is not defined on BaseLoader.
        TODO: restore phoneme/duration loading or remove this loader.
        """
        super(SpkIDLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Load the utt -> speaker-ID mapping from a JSON file."""
        with open(utt2spk, 'r', encoding='utf-8') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)

        assert sum(dur) == mel.shape[1], f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"

        # NOTE(review): phn_ids and dur are computed but not returned here, unlike
        # SpkIDLoaderWithPE; confirm whether callers expect them.
        res = {
            "utt": utt,
            "mel": mel,
            "spk_ids": spkid
        }
        return res


class SpkIDLoaderWithEmo(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2text: str, utt2spk: str, utt2emo: str):
        """
        :param utt2spk: json file path (utt name -> spk id)
        :param utt2emo: json file path (utt name -> emotion id)
        This loader loads speaker as a speaker ID for embedding table
        """
        super(SpkIDLoaderWithEmo, self).__init__(utts, hparams, feats_scp, utt2text)
        self.utt2spk = self.get_utt2spk(utt2spk)
        self.utt2emo = self.get_utt2emo(utt2emo)

    def get_utt2spk(self, utt2spk: str) -> dict:
        """Load the utt -> speaker-ID mapping from a JSON file."""
        with open(utt2spk, 'r', encoding='utf-8') as f:
            res = json.load(f)
        return res

    def get_utt2emo(self, utt2emo: str) -> dict:
        """Load the utt -> emotion-ID mapping from a JSON file."""
        with open(utt2emo, 'r', encoding='utf-8') as f:
            res = json.load(f)
        return res

    def get_mel_text_pair(self, utt):
        """Return a dict with the utt name, text IDs, mel and integer speaker/emotion IDs."""
        spkid = int(self.utt2spk[utt])
        emoid = int(self.utt2emo[utt])
        text = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)

        res = {
            "utt": utt,
            "text": text,
            "mel": mel,
            "spk_ids": spkid,
            "emo_ids": emoid
        }
        return res


class SpkIDLoaderWithPE(SpkIDLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk: str, var_scp: str):
        """
        This loader loads speaker ID together with variance (4-dim pitch, 1-dim energy)

        NOTE(review): inherits SpkIDLoader's constructor, which passes arguments that the
        current BaseLoader.__init__ does not accept (TypeError); ``get_dur_from_kaldi``
        is also not defined on BaseLoader.
        """
        super(SpkIDLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration, utt2spk)
        self.utt2var = self.get_utt2var(var_scp)

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp with kaldiio (read on access)."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        """Return the variance matrix of ``utt`` with the size-5 axis first."""
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_mel_text_pair(self, utt):
        # separate filename and text
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)

        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var
        }
        return res


class XvectorLoader(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str):
        """
        :param utt2spk_name: like kaldi-style utt2spk
        :param spk_xvector_scp: kaldi-style speaker-level xvector.scp

        NOTE(review): BaseLoader.__init__ accepts only (utts, hparams, feats_scp, utt2text),
        so the super() call below raises TypeError; ``get_dur_from_kaldi`` is also not
        defined on BaseLoader. TODO: update or remove this loader.
        """
        super(XvectorLoader, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)

    def get_utt2spk(self, utt2spk):
        """Parse a Kaldi-style ``utt2spk`` file (``<utt> <spk>`` per line); blank lines are skipped."""
        res = dict()
        with open(utt2spk, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if fields:
                    res[fields[0]] = fields[1]
        return res

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the speaker-level xvector scp with kaldiio (read on access)."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_xvector(self, utt):
        """Return the xvector of the speaker of ``utt`` as a squeezed FloatTensor."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        assert sum(dur) == mel.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "xvector": xvector,
        }
        return res


class XvectorLoaderWithPE(BaseLoader):
    def __init__(self, utts: str, hparams, feats_scp: str, utt2phns: str, phn2id: str,
                 utt2phn_duration: str, utt2spk_name: str, spk_xvector_scp: str, var_scp: str):
        """
        :param utt2spk_name: like kaldi-style utt2spk
        :param spk_xvector_scp: kaldi-style speaker-level xvector.scp
        :param var_scp: kaldi scp of per-frame variance features (size-5 axis)

        NOTE(review): BaseLoader.__init__ accepts only (utts, hparams, feats_scp, utt2text),
        so the super() call below raises TypeError; ``get_dur_from_kaldi`` is also not
        defined on BaseLoader. TODO: update or remove this loader.
        """
        super(XvectorLoaderWithPE, self).__init__(utts, hparams, feats_scp, utt2phns, phn2id, utt2phn_duration)
        self.utt2spk = self.get_utt2spk(utt2spk_name)
        self.spk2xvector = self.get_spk2xvector(spk_xvector_scp)
        self.utt2var = self.get_utt2var(var_scp)

    def get_spk2xvector(self, spk_xvector_scp: str) -> dict:
        """Open the speaker-level xvector scp with kaldiio (read on access)."""
        res = kaldiio.load_scp(spk_xvector_scp)
        print(f"Succeed reading xvector from {spk_xvector_scp}")
        return res

    def get_utt2spk(self, utt2spk):
        """Parse a Kaldi-style ``utt2spk`` file (``<utt> <spk>`` per line); blank lines are skipped."""
        res = dict()
        with open(utt2spk, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if fields:
                    res[fields[0]] = fields[1]
        return res

    def get_utt2var(self, utt2var: str) -> dict:
        """Open the variance scp with kaldiio (read on access)."""
        res = kaldiio.load_scp(utt2var)
        print(f"Succeed reading feats from {utt2var}")
        return res

    def get_var_from_kaldi(self, utt):
        """Return the variance matrix of ``utt`` with the size-5 axis first."""
        var = self.utt2var[utt]
        var = torch.FloatTensor(var).squeeze()
        assert 5 in var.shape
        if var.shape[0] == 5:
            return var
        else:
            return var.T

    def get_xvector(self, utt):
        """Return the xvector of the speaker of ``utt`` as a squeezed FloatTensor."""
        xv = self.spk2xvector[self.utt2spk[utt]]
        xv = torch.FloatTensor(xv).squeeze()
        return xv

    def get_mel_text_pair(self, utt):
        # "spk_ids" here is the speaker name string taken from the utt2spk file
        spkid = self.utt2spk[utt]
        phn_ids = self.get_text(utt)
        mel = self.get_mel_from_kaldi(utt)
        dur = self.get_dur_from_kaldi(utt)
        var = self.get_var_from_kaldi(utt)
        xvector = self.get_xvector(utt)

        assert sum(dur) == mel.shape[1] == var.shape[1], \
            f"Frame length mismatch: utt {utt}, dur: {sum(dur)}, mel: {mel.shape[1]}, var: {var.shape[1]}"

        res = {
            "utt": utt,
            "phn_ids": phn_ids,
            "mel": mel,
            "dur": dur,
            "spk_ids": spkid,
            "var": var,
            "xvector": xvector
        }
        return res