Spaces:
Runtime error
Runtime error
File size: 2,372 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import librosa
import numpy as np
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
def extract_stoi(audio_ref, audio_deg, **kwargs):
    """Compute Short-Time Objective Intelligibility between the predicted and the ground truth audio.

    audio_ref: path to the ground truth audio.
    audio_deg: path to the predicted audio.
    kwargs: hyperparameters, given either nested as ``kwargs={"fs": ..., "method": ...}``
        or directly as keyword arguments ``fs=..., method=...``.
        fs: sampling rate used to load both files. When None, both files are
            loaded with librosa's default sampling rate.
        method: only consulted when the two loaded signals differ in length.
            "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
            "cut" will cut both audios into a same length according to the one with the shorter length.

    Returns:
        The STOI score produced by torchmetrics, converted to a plain Python
        value via ``tolist()``.

    Raises:
        KeyError: if "fs" or "method" is missing from the hyperparameters.
        ValueError: if the signals differ in length and ``method`` is neither
            "cut" nor "dtw".
    """
    # Load hyperparameters (accept both the nested and the flat calling style)
    hparams = kwargs.get("kwargs", kwargs)
    fs = hparams["fs"]
    method = hparams["method"]

    # Load audio
    if fs is not None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        # Do not pass sr=None here: that would keep each file's native rate.
        # Calling load() without sr makes both signals share the rate that
        # librosa reports back, which is then used to configure the metric.
        audio_ref, fs = librosa.load(audio_ref)
        audio_deg, fs = librosa.load(audio_deg)

    # Initialize method
    stoi = ShortTimeObjectiveIntelligibility(fs, extended=False)

    # Audio length alignment
    if len(audio_ref) != len(audio_deg):
        if method == "cut":
            length = min(len(audio_ref), len(audio_deg))
            audio_ref = audio_ref[:length]
            audio_deg = audio_deg[:length]
        elif method == "dtw":
            _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
            # The warping path is produced by backtracking, i.e. it runs from
            # the end of the signals to the start. Reverse it so the aligned
            # signals stay in chronological order.
            wp = wp[::-1]
            # Fancy indexing gathers the aligned samples in one step and
            # returns fresh arrays of equal length (one sample per path step).
            audio_ref = audio_ref[wp[:, 0]]
            audio_deg = audio_deg[wp[:, 1]]
        else:
            raise ValueError(
                "Audio lengths differ ({} vs {}) and method must be 'cut' or 'dtw', "
                "got {!r}".format(len(audio_ref), len(audio_deg), method)
            )

    # Convert to tensor
    audio_ref = torch.from_numpy(audio_ref)
    audio_deg = torch.from_numpy(audio_deg)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        audio_ref = audio_ref.to(device)
        audio_deg = audio_deg.to(device)
        stoi = stoi.to(device)

    # torchmetrics expects (preds, target): degraded audio first, reference second.
    return stoi(audio_deg, audio_ref).detach().cpu().numpy().tolist()
|