File size: 2,826 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import librosa
import torch
import numpy as np
from numpy import linalg as LA
# Floor applied to frame energies before the dB conversion so that silent
# frames (energy == 0) map to a finite value (-200 dB) instead of -inf, which
# would otherwise turn the DTW cost matrix and the final RMSE into NaN/inf.
_ENERGY_FLOOR = 1e-10


def _frame_energy(audio, n_fft, hop_length, win_length):
    """Return the per-frame energy (L2 norm of the STFT magnitudes) of *audio*."""
    spec = librosa.stft(
        y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length
    )
    # Transpose so that rows are frames; the norm is then taken over the
    # frequency bins of each frame.
    mag = np.abs(spec).T
    return LA.norm(mag, axis=1)


def extract_energy_rmse(
    audio_ref,
    audio_deg,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    **kwargs,
):
    """Compute the energy Root Mean Square Error (RMSE) between two audio files.

    The per-frame energy of each signal is the L2 norm of its STFT magnitude
    spectrum. The two energy contours are optionally converted to dB, aligned
    to a common length, and compared with RMSE.

    Args:
        audio_ref: path to the ground truth audio.
        audio_deg: path to the predicted audio.
        n_fft: fft size.
        hop_length: hop length.
        win_length: window length.
        **kwargs: must contain a single entry named ``kwargs`` holding a dict
            of hyperparameters, i.e. the function is called as
            ``extract_energy_rmse(ref, deg, kwargs={...})``. The dict keys are:

            fs: sampling rate passed to ``librosa.load``. If ``None``,
                ``librosa.load`` is called without ``sr`` and therefore uses
                its own default rate.
            method: ``"dtw"`` aligns the two energy contours with dynamic
                time warping; ``"cut"`` truncates both contours to the length
                of the shorter one. Any other value performs no alignment.
            db_scale: if truthy, energies are converted to dB
                (``20 * log10``) before alignment. Energies below
                ``_ENERGY_FLOOR`` are clipped to it first so silent frames
                stay finite.

    Returns:
        The energy RMSE as a Python ``float``.

    Raises:
        KeyError: if ``kwargs`` or one of the hyperparameter keys is missing.
        ValueError: if the two energy contours still differ in length after
            the alignment step (only possible when ``method`` is neither
            ``"cut"`` nor ``"dtw"``).
    """
    # Load hyperparameters (passed as one dict under the keyword ``kwargs``).
    kwargs = kwargs["kwargs"]
    fs = kwargs["fs"]
    method = kwargs["method"]
    db_scale = kwargs["db_scale"]

    # Load audio. The two branches are intentionally separate: omitting ``sr``
    # lets librosa resample to its default rate, which is not the same as
    # passing ``sr=None`` (native rate).
    if fs is not None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        audio_ref, _ = librosa.load(audio_ref)
        audio_deg, _ = librosa.load(audio_deg)

    # STFT -> magnitudes -> per-frame energy
    energy_ref = _frame_energy(audio_ref, n_fft, hop_length, win_length)
    energy_deg = _frame_energy(audio_deg, n_fft, hop_length, win_length)

    # Convert to db_scale
    if db_scale:
        energy_ref = 20 * np.log10(np.maximum(energy_ref, _ENERGY_FLOOR))
        energy_deg = 20 * np.log10(np.maximum(energy_deg, _ENERGY_FLOOR))

    # Audio length alignment
    if method == "cut":
        length = min(len(energy_ref), len(energy_deg))
        energy_ref = energy_ref[:length]
        energy_deg = energy_deg[:length]
    elif method == "dtw":
        _, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
        # Each row of the warping path is a (ref_index, deg_index) pair; index
        # both contours along the path so they end up with equal length.
        energy_ref = energy_ref[wp[:, 0]]
        energy_deg = energy_deg[wp[:, 1]]

    if len(energy_ref) != len(energy_deg):
        raise ValueError(
            "Energy contours have different lengths "
            f"({len(energy_ref)} vs {len(energy_deg)}) after alignment with "
            f"method={method!r}; use 'cut' or 'dtw'."
        )

    # Compute RMSE
    energy_mse = np.square(np.subtract(energy_ref, energy_deg)).mean()
    energy_rmse = math.sqrt(energy_mse)

    return energy_rmse
|