# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import librosa
import torch

import numpy as np
from numpy import linalg as LA


def _amplitude_to_db(energy):
    """Convert frame energies to decibels without producing ``-inf``.

    Zero-energy (silent) frames are clamped to the smallest positive normal
    float of the array's dtype before taking the logarithm. Every value at or
    above that floor maps to exactly ``20 * log10(value)``.
    """
    floor = np.finfo(energy.dtype).tiny
    return 20 * np.log10(np.maximum(energy, floor))


def extract_energy_rmse(
    audio_ref,
    audio_deg,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    **kwargs,
):
    """Compute the energy Root Mean Square Error (RMSE) between two audios.

    The per-frame energy is the L2 norm of the STFT magnitude of each frame.

    Args:
        audio_ref: path to the ground truth audio.
        audio_deg: path to the predicted audio.
        n_fft: fft size.
        hop_length: hop length.
        win_length: window length.
        **kwargs: must contain a single entry named ``kwargs`` holding a dict
            with the following keys:

            * ``"fs"``: sampling rate used to load both audios. If ``None``,
              ``librosa.load`` is called without ``sr`` (its default rate).
            * ``"method"``: ``"dtw"`` aligns the two energy contours with the
              dtw algorithm; ``"cut"`` truncates both contours to the length
              of the shorter one.
            * ``"db_scale"``: if truthy, both energy contours are converted
              to decibels before alignment.

    Returns:
        The RMSE between the aligned energy contours, as a Python float.

    Raises:
        KeyError: if ``kwargs`` or one of its required keys is missing.
        AssertionError: if the two contours still differ in length after the
            alignment step (e.g. ``method`` is neither ``"cut"`` nor
            ``"dtw"`` and the audios have different durations).
    """
    # Load hyperparameters (callers pass them as a nested dict: kwargs=...).
    kwargs = kwargs["kwargs"]
    fs = kwargs["fs"]
    method = kwargs["method"]
    db_scale = kwargs["db_scale"]

    # Load audio. When fs is None we must NOT forward sr=None, because that
    # would switch librosa from its default rate to the file's native rate.
    load_options = {} if fs is None else {"sr": fs}
    audio_ref, _ = librosa.load(audio_ref, **load_options)
    audio_deg, _ = librosa.load(audio_deg, **load_options)

    # STFT
    stft_options = {
        "n_fft": n_fft,
        "hop_length": hop_length,
        "win_length": win_length,
    }
    spec_ref = librosa.stft(y=audio_ref, **stft_options)
    spec_deg = librosa.stft(y=audio_deg, **stft_options)

    # Get magnitudes, transposed so that each row is one frame.
    mag_ref = np.abs(spec_ref).T
    mag_deg = np.abs(spec_deg).T

    # Convert spectrogram to energy (one value per frame).
    energy_ref = LA.norm(mag_ref, axis=1)
    energy_deg = LA.norm(mag_deg, axis=1)

    # Convert to db_scale
    if db_scale:
        energy_ref = _amplitude_to_db(energy_ref)
        energy_deg = _amplitude_to_db(energy_deg)

    # Audio length alignment
    if method == "cut":
        length = min(len(energy_ref), len(energy_deg))
        energy_ref = energy_ref[:length]
        energy_deg = energy_deg[:length]
    elif method == "dtw":
        _, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
        # Each row of the warping path is a (ref_index, deg_index) pair.
        # The order of the pairs is irrelevant for the mean computed below.
        energy_ref = energy_ref[wp[:, 0]]
        energy_deg = energy_deg[wp[:, 1]]

    assert len(energy_ref) == len(energy_deg), (
        f"energy contours differ in length after alignment "
        f"({len(energy_ref)} vs {len(energy_deg)}); method={method!r}"
    )

    # Compute RMSE
    energy_mse = np.square(np.subtract(energy_ref, energy_deg)).mean()
    energy_rmse = math.sqrt(energy_mse)

    return energy_rmse