import torch
import torchaudio
import scipy.signal
import numpy as np
import pyloudnorm as pyln
import matplotlib.pyplot as plt
from deepafx_st.processors.dsp.compressor import compressor
from tqdm import tqdm


class BaselineEQ(torch.nn.Module):
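    """Baseline matching EQ.

    Measures the average magnitude spectrum of the input and the reference,
    then designs a linear-phase FIR filter with ``scipy.signal.firwin2``
    whose response maps the smoothed input spectrum onto the smoothed
    reference spectrum.
    """
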
    def __init__(
        self,
        ntaps: int = 63,
        n_fft: int = 65536,
        sample_rate: float = 44100,
    ):
        super().__init__()
        self.ntaps = ntaps
        self.n_fft = n_fft
        self.sample_rate = sample_rate

        # compute the target spectrum
        # print("Computing target spectrum...")
        # self.target_spec, self.sm_target_spec = self.analyze_speech_dataset(filepaths)
        # self.plot_spectrum(self.target_spec, filename="targetEQ")
        # self.plot_spectrum(self.sm_target_spec, filename="targetEQsm")

    def forward(self, x, y):
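        """Filter ``x`` so its average spectrum matches that of ``y``.

        Args:
            x (Tensor): Input audio with shape (bs, ch, samples).
            y (Tensor): Reference audio with shape (bs, ch, samples).

        Returns:
            Tensor: Filtered input audio.
        """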
        bs, ch, s = x.size()

        x = x.view(bs * ch, -1)
        y = y.view(bs * ch, -1)

        in_spec = self.get_average_spectrum(x)
        ref_spec = self.get_average_spectrum(y)

        sm_in_spec = self.smooth_spectrum(in_spec)
        sm_ref_spec = self.smooth_spectrum(ref_spec)

        # self.plot_spectrum(in_spec, filename="inSpec")
        # self.plot_spectrum(sm_in_spec, filename="inSpecsm")

        # design inverse FIR filter to match target EQ
        freqs = np.linspace(0, 1.0, num=(self.n_fft // 2) + 1)
        response = sm_ref_spec / sm_in_spec
        response[-1] = 0.0  # zero gain at nyquist

        b = scipy.signal.firwin2(
            self.ntaps,
            freqs * (self.sample_rate / 2),
            response,
            fs=self.sample_rate,
        )
        # scale the coefficients for a less intense filter
        # b *= 0.5
        # apply the filter
        x_filt = scipy.signal.lfilter(b, [1.0], x.numpy())
        x_filt = torch.tensor(x_filt.astype("float32"))
        if False:  # set to True to plot the filter response for debugging
            w, h = scipy.signal.freqz(b, fs=self.sample_rate, worN=response.shape[-1])

            fig, ax1 = plt.subplots()
            ax1.set_title("Digital filter frequency response")
            ax1.plot(w, 20 * np.log10(np.abs(h) + 1e-8))
            ax1.plot(w, 20 * np.log10(np.abs(response) + 1e-8))
            ax1.set_xscale("log")
            ax1.set_ylim([-12, 12])
            plt.grid(c="lightgray")
            plt.savefig("inverse.png")
        x_filt_avg_spec = self.get_average_spectrum(x_filt)
        sm_x_filt_avg_spec = self.smooth_spectrum(x_filt_avg_spec)

        y_avg_spec = self.get_average_spectrum(y)
        sm_y_avg_spec = self.smooth_spectrum(y_avg_spec)

        compare = torch.stack(
            [
                torch.tensor(sm_in_spec),
                torch.tensor(sm_x_filt_avg_spec),
                torch.tensor(sm_ref_spec),
                torch.tensor(sm_y_avg_spec),
            ]
        )

        self.plot_multi_spectrum(
            compare,
            legend=["in", "out", "target curve", "actual target"],
            filename="outSpec",
        )

        return x_filt

    def analyze_speech_dataset(self, filepaths, peak=-3.0):
        avg_spec = []
        for filepath in tqdm(filepaths, ncols=80):
            x, sr = torchaudio.load(filepath)
            x /= x.abs().max()
            x *= 10 ** (peak / 20.0)
            avg_spec.append(self.get_average_spectrum(x))
        avg_specs = torch.stack(avg_spec)

        avg_spec = avg_specs.mean(dim=0).numpy()
        avg_spec_std = avg_specs.std(dim=0).numpy()

        # self.plot_multi_spectrum(avg_specs, filename="allTargetEQs")
        # self.plot_spectrum_stats(avg_spec, avg_spec_std, filename="targetEQstats")

        sm_avg_spec = self.smooth_spectrum(avg_spec)

        return avg_spec, sm_avg_spec

    def smooth_spectrum(self, H):
        # apply Savgol filter for smoothed target curve
        return scipy.signal.savgol_filter(H, 1025, 2)

    def get_average_spectrum(self, x):
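        """Compute the average magnitude spectrum of ``x``.

        Takes the STFT with a window size of ``self.n_fft`` and averages
        the magnitudes across frames. Assumes a single audio channel,
        since the result is flattened to one spectrum.
        """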
        # x = x[:, : self.n_fft]
        X = torch.stft(x, self.n_fft, return_complex=True, normalized=True)
        # fft_size = self.next_power_of_2(x.shape[-1])
        # X = torch.fft.rfft(x, n=fft_size)
        X = X.abs()  # convert to magnitude
        X = X.mean(dim=-1).view(-1)  # average across frames

        return X

    @staticmethod
    def next_power_of_2(x):
        return 1 if x == 0 else int(2 ** np.ceil(np.log2(x)))

    def plot_multi_spectrum(self, Hs, legend=[], filename=None):
        bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
        freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)

        fig, ax1 = plt.subplots()
        for H in Hs:
            ax1.plot(
                freqs,
                20 * np.log10(abs(H) + 1e-8),
            )
        plt.legend(legend)

        # avg_spec = Hs.mean(dim=0).numpy()
        # ax1.plot(freqs, 20 * np.log10(avg_spec), color="k", linewidth=2)

        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")

        if filename is not None:
            plt.savefig(f"{filename}.png")

    def plot_spectrum_stats(self, H_mean, H_std, filename=None):
        bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
        freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)

        fig, ax1 = plt.subplots()
        ax1.plot(freqs, 20 * np.log10(H_mean + 1e-8))
        ax1.plot(
            freqs,
            (20 * np.log10(H_mean + 1e-8)) + (20 * np.log10(H_std + 1e-8)),
            linestyle="--",
            color="k",
        )
        ax1.plot(
            freqs,
            (20 * np.log10(H_mean + 1e-8)) - (20 * np.log10(H_std + 1e-8)),
            linestyle="--",
            color="k",
        )
        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")

        if filename is not None:
            plt.savefig(f"{filename}.png")

    def plot_spectrum(self, H, legend=[], filename=None):
        bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
        freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)

        fig, ax1 = plt.subplots()
        ax1.plot(freqs, 20 * np.log10(H + 1e-8))
        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")
        plt.legend(legend)

        if filename is not None:
            plt.savefig(f"{filename}.png")


class BaselineComp(torch.nn.Module):
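    """Baseline loudness-matching compressor.

    Iteratively lowers the compressor threshold until the integrated
    loudness (LUFS) of the compressed input is within 0.5 dB of the
    reference, or the threshold reaches -80 dB.
    """
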
    def __init__(
        self,
        sample_rate: float = 44100,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.meter = pyln.Meter(sample_rate)

    def forward(self, x, y):
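        """Compress ``x`` until its loudness approaches that of ``y``."""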
        x_lufs = self.meter.integrated_loudness(x.view(-1).numpy())
        y_lufs = self.meter.integrated_loudness(y.view(-1).numpy())

        delta_lufs = y_lufs - x_lufs

        threshold = 0.0
        x_comp = x
        x_comp_new = x
        # lower the threshold in 0.5 dB steps, re-compressing the original
        # input each time, until the loudness gap closes (or bottoms out)
        while delta_lufs > 0.5 and threshold > -80.0:
            x_comp = x_comp_new  # use the last setting
            x_comp_new = compressor(
                x.view(-1).numpy(),
                self.sample_rate,
                threshold=threshold,
                ratio=3,
                attack_time=0.001,
                release_time=0.05,
                knee_dB=6.0,
                makeup_gain_dB=0.0,
            )
            x_comp_new = torch.tensor(x_comp_new)
            x_comp_new /= x_comp_new.abs().max()
            x_comp_new *= 10 ** (-12.0 / 20)  # normalize peak to -12 dBFS
            x_lufs = self.meter.integrated_loudness(x_comp_new.view(-1).numpy())
            delta_lufs = y_lufs - x_lufs
            threshold -= 0.5

        return x_comp.view(1, 1, -1)


class BaselineEQAndComp(torch.nn.Module):
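    """Baseline production style transfer: matching EQ followed by a
    loudness-matching compressor, with peak normalization to -12 dBFS
    between each stage.
    """
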
    def __init__(
        self,
        ntaps=63,
        n_fft=65536,
        sample_rate=44100,
        block_size=1024,
        plugin_config=None,
    ):
        super().__init__()
        self.eq = BaselineEQ(ntaps, n_fft, sample_rate)
        self.comp = BaselineComp(sample_rate)

    def forward(self, x, y):
        with torch.inference_mode():
            # normalize peaks to -12 dBFS before equalization
            x /= x.abs().max()
            y /= y.abs().max()
            x *= 10 ** (-12.0 / 20)
            y *= 10 ** (-12.0 / 20)

            x = self.eq(x, y)

            # re-normalize peaks before compression
            x /= x.abs().max()
            y /= y.abs().max()
            x *= 10 ** (-12.0 / 20)
            y *= 10 ** (-12.0 / 20)

            x = self.comp(x, y)

            # final peak normalization of the output
            x /= x.abs().max()
            x *= 10 ** (-12.0 / 20)

        return x
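

# Example usage: a minimal sketch, not part of the original module.
# "input.wav" and "reference.wav" are hypothetical paths; both files are
# assumed to be mono and at least n_fft samples long.
if __name__ == "__main__":
    x, sr = torchaudio.load("input.wav")  # audio to process (hypothetical path)
    y, _ = torchaudio.load("reference.wav")  # style reference (hypothetical path)

    baseline = BaselineEQAndComp(sample_rate=sr)
    x_out = baseline(x.view(1, 1, -1), y.view(1, 1, -1))

    torchaudio.save("output.wav", x_out.view(1, -1), sr)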