Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import os | |
import os.path as osp | |
import tqdm | |
import torch | |
import torch.nn.functional as F | |
from shutil import copyfile | |
from npy_append_array import NpyAppendArray | |
import fairseq | |
import soundfile as sf | |
def get_parser():
    """Build the command-line parser for this feature-extraction script.

    Returns:
        An ``argparse.ArgumentParser`` taking a positional ``data`` directory
        (where ``<split>.tsv`` lives) plus ``--split``, ``--save-dir``,
        ``--checkpoint`` (all required) and ``--layer`` (default 14).
    """
    parser = argparse.ArgumentParser(
        # The previous description ("compute kmeans codebook from
        # kaldi-computed feats") belonged to a different script.
        description="extract wav2vec features for the audio files listed in a tsv manifest"
    )
    # fmt: off
    parser.add_argument('data', help='location of tsv files')
    parser.add_argument('--split', help='which split to read', required=True)
    parser.add_argument('--save-dir', help='where to save the output', required=True)
    parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec ctc model', required=True)
    parser.add_argument('--layer', type=int, default=14, help='which layer to use')
    # fmt: on
    return parser
class Wav2VecFeatureReader(object):
    """Compute per-file features from a fairseq wav2vec checkpoint.

    The checkpoint is loaded once at construction, put in eval mode and moved
    to the GPU. ``get_feats`` then maps one audio file to the model's ``"x"``
    output for the configured ``layer``.
    """

    def __init__(self, cp_file, layer):
        """Load ``cp_file`` and keep the first model of the loaded ensemble.

        Args:
            cp_file: path to the checkpoint, passed to
                ``fairseq.checkpoint_utils.load_model_ensemble_and_task``.
            layer: forwarded unchanged as the model's ``layer=`` argument in
                ``get_feats``.
        """
        models, _cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_file]
        )
        model = models[0]
        model.eval()
        model.cuda()
        self.model = model
        # Kept so get_feats can read task.cfg.normalize.
        self.task = task
        self.layer = layer

    def read_audio(self, fname):
        """Load an audio file with soundfile and return its samples.

        Only the samples are returned. The sample rate is checked but not
        returned.

        Raises:
            ValueError: if the file is not sampled at 16 kHz.
        """
        wav, sr = sf.read(fname)
        # An explicit check, unlike ``assert``, still runs under ``python -O``.
        if sr != 16e3:
            raise ValueError(f"{fname}: expected a 16000 Hz sample rate, got {sr}")
        return wav

    def get_feats(self, loc):
        """Return the model's ``"x"`` output for the audio file ``loc`` as a CPU tensor.

        The waveform is converted to a float32 CUDA tensor. When the task
        config sets ``normalize``, it is layer-normalised over its whole length.
        It is then reshaped to ``(1, -1)`` and run through the model with
        ``mask=False, features_only=True`` and the configured ``layer``.
        Finally, dim 0 (presumably the size-1 batch dim) is squeezed away.

        Raises:
            ValueError: if the sample rate is not 16 kHz (see ``read_audio``),
                or if the decoded audio is not one-dimensional (mono).
        """
        x = self.read_audio(loc)
        # view(1, -1) below would silently interleave the channels of
        # multi-channel audio, so require a 1-D signal regardless of
        # normalize. The check runs on the numpy array, before any GPU copy.
        if x.ndim != 1:
            raise ValueError(f"{loc}: expected mono (1-D) audio, got {x.ndim} dims")
        with torch.no_grad():
            source = torch.from_numpy(x).float().cuda()
            if self.task.cfg.normalize:
                # Normalised over all samples of the 1-D waveform.
                source = F.layer_norm(source, source.shape)
            source = source.view(1, -1)
            m_res = self.model(source=source, mask=False, features_only=True, layer=self.layer)
            return m_res["x"].squeeze(0).cpu()
def get_iterator(args):
    """Prepare a lazy feature generator over the files of a tsv manifest.

    The manifest ``<args.data>/<args.split>.tsv`` is read in full. Its first
    line (stripped) is taken as the audio root directory. Every following
    non-empty line contributes ``<root>/<first tab-separated field>``. The
    feature reader is built eagerly from ``args.checkpoint`` and
    ``args.layer``.

    Returns:
        ``(iterate, num)``: ``iterate`` is a zero-argument generator function
        yielding ``Wav2VecFeatureReader.get_feats`` for each listed file in
        order, and ``num`` is the number of listed files.
    """
    tsv_path = osp.join(args.data, args.split) + ".tsv"
    with open(tsv_path, "r") as manifest:
        rows = manifest.read().split("\n")

    # str.split always yields at least one element, so rows[0] exists.
    audio_root = rows[0].strip()
    audio_paths = []
    for row in rows[1:]:
        if not row:
            continue
        audio_paths.append(osp.join(audio_root, row.split("\t")[0]))
    total = len(audio_paths)

    reader = Wav2VecFeatureReader(args.checkpoint, args.layer)

    def iterate():
        for path in audio_paths:
            yield reader.get_feats(path)

    return iterate, total
def main():
    """Extract features for one split and write them under ``--save-dir``.

    Outputs, all prefixed with ``<save_dir>/<split>``:
      * ``.tsv``: a copy of the input manifest.
      * ``.wrd`` / ``.phn``: copies of those files, if present next to the
        manifest.
      * ``.npy``: every non-empty feature tensor appended in manifest order
        (any existing file is removed first).
      * ``.lengths``: one ``len()`` value per manifest entry, presumably the
        frame count (dim 0 of each feature tensor). Zero-length entries are
        listed here but contribute nothing to the ``.npy``.
    """
    parser = get_parser()
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    src_prefix = osp.join(args.data, args.split)

    def create_files(dest):
        """Copy the manifest and any transcripts to ``dest``, then open a fresh ``.npy`` appender."""
        copyfile(src_prefix + ".tsv", dest + ".tsv")
        for ext in (".wrd", ".phn"):
            if osp.exists(src_prefix + ext):
                copyfile(src_prefix + ext, dest + ext)
        # Remove stale output so the appender starts from an empty array
        # (presumably it would otherwise extend the existing file).
        if osp.exists(dest + ".npy"):
            os.remove(dest + ".npy")
        return NpyAppendArray(dest + ".npy")

    save_path = osp.join(args.save_dir, args.split)
    npaa = create_files(save_path)
    try:
        generator, num = get_iterator(args)
        iterator = generator()

        with open(save_path + ".lengths", "w") as l_f:
            for w2v_feats in tqdm.tqdm(iterator, total=num):
                print(len(w2v_feats), file=l_f)
                if len(w2v_feats) > 0:
                    npaa.append(w2v_feats.numpy())
    finally:
        # The original never released the .npy handle. Close it even if
        # extraction fails part-way.
        npaa.close()


if __name__ == "__main__":
    main()