Spaces:
Running
Running
""" | |
Implementation of differentiable mastering effects based on DASP-pytorch and torchcomp libraries | |
- Distortion | |
- Multiband Compressor | |
- Limiter | |
DASP-pytorch: https://github.com/csteinmetz1/dasp-pytorch | |
torchcomp: https://github.com/yoyololicon/torchcomp | |
""" | |
import dasp_pytorch | |
from dasp_pytorch.modules import Processor | |
import torchcomp | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import numpy as np | |
import time | |
# Tiny margin used to keep parameter ranges strictly inside open intervals,
# e.g. thresholds just below 0 dB and ratios just above 1.0 / inside (0, 1).
EPS = 1.0e-06
class Distortion(Processor):
    """Processor wrapper around the :func:`distortion` effect function.

    Declares two parameters:
      * ``drive_db``               -- input drive in dB, in [min_gain_db, max_gain_db]
      * ``parallel_weight_factor`` -- wet/dry blend, in [0.2, 0.7]

    The insertion order of ``param_ranges`` matters if the Processor base class
    maps normalized parameter columns to names by key order (as dasp_pytorch's
    Processor presumably does -- confirm).
    """

    def __init__(
        self,
        sample_rate: int,
        min_gain_db: float = 0.0,
        max_gain_db: float = 24.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = distortion

        ranges = dict()
        ranges["drive_db"] = (min_gain_db, max_gain_db)
        ranges["parallel_weight_factor"] = (0.2, 0.7)
        self.param_ranges = ranges
        self.num_params = len(ranges)
def distortion(x: torch.Tensor,
               sample_rate: int,
               drive_db: torch.Tensor,
               parallel_weight_factor: torch.Tensor):
    """Soft-clipping (tanh) distortion with drive and parallel (wet/dry) mix.

    Args:
        x (torch.Tensor): Input audio tensor with shape (bs, chs, seq_len).
        sample_rate (int): Audio sample rate. Not used by this effect; kept so
            all effect functions in this module share the
            ``(x, sample_rate, **params)`` calling convention.
        drive_db (torch.Tensor): Drive in dB, shape (bs,), or a single value
            that is broadcast over the whole batch.
        parallel_weight_factor (torch.Tensor): Wet/dry blend, shape (bs,) or a
            single value. 1.0 returns only the distorted signal, 0.0 returns
            the dry input unchanged.

    Returns:
        torch.Tensor: Output audio tensor with shape (bs, chs, seq_len).

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    if x.dim() != 3:
        raise ValueError(
            f"expected x with shape (bs, chs, seq_len), got {tuple(x.shape)}")

    # Reshape per-item parameters to (bs, 1, 1) so they broadcast over
    # channels and time; a single value broadcasts over the batch as well.
    drive_gain = 10 ** (drive_db.reshape(-1, 1, 1) / 20.0)
    mix = parallel_weight_factor.reshape(-1, 1, 1)

    x_dist = torch.tanh(x * drive_gain)
    # Parallel computation: blend the distorted (wet) and original (dry) signal.
    return mix * x_dist + (1 - mix) * x
class Multiband_Compressor(Processor):
    """Processor wrapper around the :func:`multiband_compressor` effect function.

    Parameter layout (key order is significant if the Processor base class
    reads normalized parameter columns by key order -- presumably the case
    for dasp_pytorch's Processor, confirm):
      1. ``low_cutoff`` (20-300 Hz), ``high_cutoff`` (2-12 kHz),
         ``parallel_weight_factor`` (0.2-0.7)
      2. for each band in low_shelf -> mid_band -> high_shelf:
         comp_thresh, comp_ratio, exp_thresh, exp_ratio, at, rt

    NOTE(review): ``min/max_attack_ms_comp`` and ``min/max_release_ms_comp``
    are accepted but never used -- every band's ``*_at`` / ``*_rt`` range comes
    from the ``*_exp`` arguments, because each band's compressor and expander
    share a single attack/release pair. Kept for backward compatibility;
    confirm whether that is intended.
    """

    def __init__(
        self,
        sample_rate: int,
        min_threshold_db_comp: float = -60.0,
        max_threshold_db_comp: float = 0.0 - EPS,
        min_ratio_comp: float = 1.0 + EPS,
        max_ratio_comp: float = 20.0,
        min_attack_ms_comp: float = 5.0,
        max_attack_ms_comp: float = 100.0,
        min_release_ms_comp: float = 5.0,
        max_release_ms_comp: float = 100.0,
        min_threshold_db_exp: float = -60.0,
        max_threshold_db_exp: float = 0.0 - EPS,
        min_ratio_exp: float = 0.0 + EPS,
        max_ratio_exp: float = 1.0 - EPS,
        min_attack_ms_exp: float = 5.0,
        max_attack_ms_exp: float = 100.0,
        min_release_ms_exp: float = 5.0,
        max_release_ms_exp: float = 100.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = multiband_compressor

        # Ranges shared by all three bands.
        comp_thresh_range = (min_threshold_db_comp, max_threshold_db_comp)
        comp_ratio_range = (min_ratio_comp, max_ratio_comp)
        exp_thresh_range = (min_threshold_db_exp, max_threshold_db_exp)
        exp_ratio_range = (min_ratio_exp, max_ratio_exp)
        attack_range = (min_attack_ms_exp, max_attack_ms_exp)
        release_range = (min_release_ms_exp, max_release_ms_exp)

        # Crossover frequencies and the dry/wet blend come first ...
        ranges = {
            "low_cutoff": (20, 300),
            "high_cutoff": (2000, 12000),
            "parallel_weight_factor": (0.2, 0.7),
        }
        # ... followed by six dynamics parameters per band, low -> mid -> high.
        for band in ("low_shelf", "mid_band", "high_shelf"):
            ranges[f"{band}_comp_thresh"] = comp_thresh_range
            ranges[f"{band}_comp_ratio"] = comp_ratio_range
            ranges[f"{band}_exp_thresh"] = exp_thresh_range
            ranges[f"{band}_exp_ratio"] = exp_ratio_range
            ranges[f"{band}_at"] = attack_range
            ranges[f"{band}_rt"] = release_range

        self.param_ranges = ranges
        self.num_params = len(ranges)
def linkwitz_riley_4th_order(
        x: torch.Tensor,
        cutoff_freq: torch.Tensor,
        sample_rate: float,
        filter_type: str):
    """Apply a 4th-order Linkwitz-Riley low- or high-pass filter.

    A single 2nd-order section with 0 dB gain and Q = 1/sqrt(2) (Butterworth)
    is designed with ``dasp_pytorch.signal.biquad`` and applied twice in
    cascade with ``dasp_pytorch.signal.sosfilt_via_fsm`` (frequency-sampling
    filtering), giving the 4th-order Linkwitz-Riley response.

    Args:
        x (torch.Tensor): Audio to filter; the callers in this file pass
            (bs, chs, seq_len).
        cutoff_freq (torch.Tensor): Cutoff frequency in Hz, one value per
            batch item (callers pass shape (bs, 1, 1)).
        sample_rate (float): Audio sample rate.
        filter_type (str): Biquad type forwarded to ``biquad``; this file uses
            "low_pass" and "high_pass".

    Returns:
        torch.Tensor: Filtered audio as returned by ``sosfilt_via_fsm``
        (presumably the same shape as ``x`` -- confirm).
    """
    # 0 dB gain and Butterworth Q for every batch item, allocated directly on
    # x's device instead of being built on the CPU and copied over.
    gain_db = torch.zeros(cutoff_freq.shape, device=x.device)
    q_factor = torch.full(cutoff_freq.shape, 1.0 / 2.0 ** 0.5, device=x.device)

    b, a = dasp_pytorch.signal.biquad(
        gain_db,
        cutoff_freq,
        q_factor,
        sample_rate,
        filter_type
    )

    # One second-order section per batch item: concatenated (b, a)
    # coefficients with a section axis inserted -> presumably (bs, 1, 6).
    sos = torch.cat((b, a), dim=-1).unsqueeze(1)

    # Cascade the same Butterworth section twice: the squared response is the
    # LR4 response, whose low/high-pass outputs differ in phase by 360 degrees.
    y = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
    return dasp_pytorch.signal.sosfilt_via_fsm(sos, y)
def _apply_band_compexp(
        band: torch.Tensor,
        sample_rate: float,
        comp_thresh: torch.Tensor,
        comp_ratio: torch.Tensor,
        exp_thresh: torch.Tensor,
        exp_ratio: torch.Tensor,
        at_ms: torch.Tensor,
        rt_ms: torch.Tensor,
        band_label: str):
    """Apply torchcomp's compressor/expander gain to one crossover band.

    The side-chain is the absolute value of the channel-summed band,
    (bs, seq_len); the resulting gain is broadcast back over all channels.
    This is best-effort: if torchcomp raises for this input (for instance it
    may reject some side-chain values -- confirm which), a message is printed
    and the band is returned unprocessed.
    """
    try:
        gain = torchcomp.compexp_gain(
            band.sum(axis=1).abs(),
            comp_thresh=comp_thresh,
            comp_ratio=comp_ratio,
            exp_thresh=exp_thresh,
            exp_ratio=exp_ratio,
            at=torchcomp.ms2coef(at_ms, sample_rate),
            rt=torchcomp.ms2coef(rt_ms, sample_rate),
        )
        return band * gain.unsqueeze(1)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        print(f'\t!!!failed computing {band_label} compression!!!')
        return band


def multiband_compressor(
    x: torch.Tensor,
    sample_rate: float,
    low_cutoff: torch.Tensor,
    high_cutoff: torch.Tensor,
    parallel_weight_factor: torch.Tensor,
    low_shelf_comp_thresh: torch.Tensor,
    low_shelf_comp_ratio: torch.Tensor,
    low_shelf_exp_thresh: torch.Tensor,
    low_shelf_exp_ratio: torch.Tensor,
    low_shelf_at: torch.Tensor,
    low_shelf_rt: torch.Tensor,
    mid_band_comp_thresh: torch.Tensor,
    mid_band_comp_ratio: torch.Tensor,
    mid_band_exp_thresh: torch.Tensor,
    mid_band_exp_ratio: torch.Tensor,
    mid_band_at: torch.Tensor,
    mid_band_rt: torch.Tensor,
    high_shelf_comp_thresh: torch.Tensor,
    high_shelf_comp_ratio: torch.Tensor,
    high_shelf_exp_thresh: torch.Tensor,
    high_shelf_exp_ratio: torch.Tensor,
    high_shelf_at: torch.Tensor,
    high_shelf_rt: torch.Tensor,
):
    """Multiband (three-band) compressor/expander with parallel mix.

    The input is split into low / mid / high bands by Linkwitz-Riley
    crossovers. Each band gets its own compressor+expander gain computed by
    torchcomp. The bands are summed and blended with the dry input.

    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len).
        sample_rate (float): Audio sample rate.
        low_cutoff (torch.Tensor): Low/mid crossover frequency in Hz, (bs,).
        high_cutoff (torch.Tensor): Mid/high crossover frequency in Hz, (bs,).
        parallel_weight_factor (torch.Tensor): Wet/dry blend, (bs,).
            1.0 keeps only the processed signal, 0.0 only the dry input.
        <band>_comp_thresh (torch.Tensor): Compressor threshold in dB.
        <band>_comp_ratio (torch.Tensor): Compressor ratio.
        <band>_exp_thresh (torch.Tensor): Expander threshold in dB.
        <band>_exp_ratio (torch.Tensor): Expander ratio.
        <band>_at (torch.Tensor): Attack time in ms (converted by ms2coef).
        <band>_rt (torch.Tensor): Release time in ms (converted by ms2coef).
        where <band> is low_shelf, mid_band or high_shelf.

    Returns:
        torch.Tensor: Processed signal blended with the dry input,
        shape (bs, chs, seq_len).
    """
    bs, chs, seq_len = x.size()
    low_cutoff = low_cutoff.view(-1, 1, 1)
    high_cutoff = high_cutoff.view(-1, 1, 1)
    parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)

    # --- crossover ---
    low_band = linkwitz_riley_4th_order(x, low_cutoff, sample_rate, filter_type="low_pass")
    high_band = linkwitz_riley_4th_order(x, high_cutoff, sample_rate, filter_type="high_pass")
    # Mid band = whatever the low and high filters did not take.
    # NOTE(review): for Linkwitz-Riley crossovers LP + HP generally sums to an
    # all-pass rather than to x, so this residual only approximates a true
    # band-pass -- confirm this is intended.
    mid_band = x - low_band - high_band

    # --- per-band dynamics (same order as before: low, high, mid) ---
    x_out_low = _apply_band_compexp(
        low_band, sample_rate,
        comp_thresh=low_shelf_comp_thresh, comp_ratio=low_shelf_comp_ratio,
        exp_thresh=low_shelf_exp_thresh, exp_ratio=low_shelf_exp_ratio,
        at_ms=low_shelf_at, rt_ms=low_shelf_rt, band_label='low-band')
    x_out_high = _apply_band_compexp(
        high_band, sample_rate,
        comp_thresh=high_shelf_comp_thresh, comp_ratio=high_shelf_comp_ratio,
        exp_thresh=high_shelf_exp_thresh, exp_ratio=high_shelf_exp_ratio,
        at_ms=high_shelf_at, rt_ms=high_shelf_rt, band_label='high-band')
    x_out_mid = _apply_band_compexp(
        mid_band, sample_rate,
        comp_thresh=mid_band_comp_thresh, comp_ratio=mid_band_comp_ratio,
        exp_thresh=mid_band_exp_thresh, exp_ratio=mid_band_exp_ratio,
        at_ms=mid_band_at, rt_ms=mid_band_rt, band_label='mid-band')

    x_out = x_out_low + x_out_high + x_out_mid
    # Parallel computation: blend processed (wet) and original (dry) signal.
    x_out = parallel_weight_factor * x_out + (1 - parallel_weight_factor) * x
    return x_out.view(bs, chs, seq_len)
class Limiter(Processor):
    """Processor wrapper around the :func:`limiter` effect function.

    Declares three parameters, in this order:
      * ``threshold`` -- limiter threshold in dB
      * ``at``        -- attack time in ms
      * ``rt``        -- release time in ms
    """

    def __init__(
        self,
        sample_rate: int,
        min_threshold_db: float = -60.0,
        max_threshold_db: float = 0.0 - EPS,
        min_attack_ms: float = 5.0,
        max_attack_ms: float = 100.0,
        min_release_ms: float = 5.0,
        max_release_ms: float = 100.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = limiter
        # dict(...) keyword order is preserved, so the key order is unchanged.
        self.param_ranges = dict(
            threshold=(min_threshold_db, max_threshold_db),
            at=(min_attack_ms, max_attack_ms),
            rt=(min_release_ms, max_release_ms),
        )
        self.num_params = len(self.param_ranges)
def limiter(
    x: torch.Tensor,
    sample_rate: float,
    threshold: torch.Tensor,
    at: torch.Tensor,
    rt: torch.Tensor,
):
    """Limiter based on ``torchcomp.limiter_gain`` (from Chin-yun's paper).

    One gain curve is computed per batch item from the absolute value of the
    channel-summed signal and applied identically to every channel.

    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len).
        sample_rate (float): Audio sample rate.
        threshold (torch.Tensor): Limiter threshold in dB, (bs,).
        at (torch.Tensor): Attack time in ms, (bs,); converted with ms2coef.
        rt (torch.Tensor): Release time in ms, (bs,); converted with ms2coef.

    Returns:
        torch.Tensor: Limited signal with shape (bs, chs, seq_len).
    """
    bs, chs, seq_len = x.size()
    # Side-chain: |sum over channels|, shape (bs, seq_len).
    sidechain = x.sum(axis=1).abs()
    gain = torchcomp.limiter_gain(
        sidechain,
        threshold=threshold,
        at=torchcomp.ms2coef(at, sample_rate),
        rt=torchcomp.ms2coef(rt, sample_rate),
    )
    # Broadcast the shared gain over the channel axis.
    x_out = x * gain.unsqueeze(1)
    return x_out.view(bs, chs, seq_len)
class Random_Augmentation_Dasp(nn.Module):
    """Randomly applied chain of differentiable mastering effects.

    The effects named in ``tgt_fx_names`` are applied in that order. Each
    effect is computed for the whole batch. A per-item random mask (item kept
    with probability ``fx_prob[name]``) then decides whether that item keeps
    the processed or the unprocessed signal.

    Supported names: 'eq', 'distortion', 'comp', 'multiband_comp', 'gain',
    'imager', 'limiter'.
    """

    def __init__(self, sample_rate,
                 tgt_fx_names=('eq', 'comp', 'imager', 'gain')):
        """
        Args:
            sample_rate: Audio sample rate forwarded to every effect.
            tgt_fx_names: Ordered iterable of effect names forming the chain.
                Copied into a list, so later changes to the caller's object do
                not affect this module. (The default is a tuple to avoid a
                shared mutable default.)

        Raises:
            AssertionError: If an effect name is not supported.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.tgt_fx_names = list(tgt_fx_names)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Probability that each effect is applied to a given batch item.
        self.fx_prob = {'eq': 0.9,
                        'distortion': 0.3,
                        'comp': 0.8,
                        'multiband_comp': 0.8,
                        'gain': 0.85,
                        'imager': 0.6,
                        'limiter': 1.0}

        self.fx_processors = {}
        for cur_fx in self.tgt_fx_names:
            if cur_fx == 'imager':
                # The stereo widener is called functionally in forward(); it
                # has no Processor object and consumes one parameter.
                continue
            self.fx_processors[cur_fx] = self._build_processor(cur_fx, sample_rate)

        total_num_param = sum(proc.num_params for proc in self.fx_processors.values())
        if 'imager' in self.tgt_fx_names:
            total_num_param += 1
        self.total_num_param = total_num_param

    @staticmethod
    def _build_processor(fx_name, sample_rate):
        """Construct the Processor for one effect name (any name except 'imager')."""
        if fx_name == 'eq':
            return dasp_pytorch.ParametricEQ(sample_rate=sample_rate,
                                             min_gain_db=-10.0,
                                             max_gain_db=10.0,
                                             min_q_factor=0.5,
                                             max_q_factor=5.0)
        if fx_name == 'distortion':
            return Distortion(sample_rate=sample_rate,
                              min_gain_db=0.0,
                              max_gain_db=4.0)
        if fx_name == 'comp':
            return dasp_pytorch.Compressor(sample_rate=sample_rate)
        if fx_name == 'multiband_comp':
            return Multiband_Compressor(sample_rate=sample_rate,
                                        min_threshold_db_comp=-30.0,
                                        max_threshold_db_comp=-5.0,
                                        min_ratio_comp=1.5,
                                        max_ratio_comp=6.0,
                                        min_attack_ms_comp=1.0,
                                        max_attack_ms_comp=20.0,
                                        min_release_ms_comp=20.0,
                                        max_release_ms_comp=500.0,
                                        min_threshold_db_exp=-30.0,
                                        max_threshold_db_exp=-5.0,
                                        min_ratio_exp=0.0 + EPS,
                                        max_ratio_exp=1.0 - EPS,
                                        min_attack_ms_exp=1.0,
                                        max_attack_ms_exp=20.0,
                                        min_release_ms_exp=20.0,
                                        max_release_ms_exp=500.0)
        if fx_name == 'gain':
            return dasp_pytorch.Gain(sample_rate=sample_rate,
                                     min_gain_db=0.0,
                                     max_gain_db=6.0)
        if fx_name == 'limiter':
            return Limiter(sample_rate=sample_rate,
                           min_threshold_db=-20.0,
                           max_threshold_db=0.0 - EPS,
                           min_attack_ms=0.1,
                           max_attack_ms=5.0,
                           min_release_ms=20.0,
                           max_release_ms=1000.0)
        raise AssertionError(f"current fx name ({fx_name}) not found")

    # network forward operation
    def forward(self, x, rand_param=None, use_mask=None):
        """Apply the randomized effect chain to ``x``.

        Args:
            x: Audio batch; the effect functions in this file expect
                (bs, chs, seq_len).
            rand_param: Optional (bs, total_num_param) tensor of normalized
                effect parameters, consumed left to right in chain order.
                Drawn uniformly in [0, 1) when None.
            use_mask: Optional dict mapping effect name to a boolean tensor
                broadcastable against ``x`` (``random_mask_generator`` yields
                (bs, 1, 1)). Drawn when None.

        Returns:
            Processed batch with the same shape as ``x``.

        Raises:
            AssertionError: If ``rand_param`` has the wrong shape.
        """
        if rand_param is None:
            # Drawn on the CPU generator (as before), then moved to x's device
            # so parameters always live where the audio does.
            rand_param = torch.rand((x.shape[0], self.total_num_param)).to(x.device)
        elif rand_param.shape[0] != x.shape[0] or rand_param.shape[1] != self.total_num_param:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                f"rand_param must have shape ({x.shape[0]}, {self.total_num_param}), "
                f"got {tuple(rand_param.shape)}")
        if use_mask is None:
            use_mask = self.random_mask_generator(x.shape[0])

        # dafx chain
        cur_param_idx = 0
        for cur_fx in self.tgt_fx_names:
            cur_param_count = 1 if cur_fx == 'imager' else self.fx_processors[cur_fx].num_params
            if cur_fx == 'imager':
                x_processed = dasp_pytorch.functional.stereo_widener(
                    x,
                    sample_rate=self.sample_rate,
                    width=rand_param[:, cur_param_idx:cur_param_idx + 1])
            else:
                cur_input_param = rand_param[:, cur_param_idx:cur_param_idx + cur_param_count]
                x_processed = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
            # Every effect is processed for the full batch; the mask decides per
            # item which signal is kept. torch.where (instead of
            # `processed*mask + x*~mask`) keeps NaN/inf from a masked-out
            # item's processed output from leaking into its kept dry signal.
            cur_mask = use_mask[cur_fx].to(x.device)
            x = torch.where(cur_mask, x_processed, x)
            cur_param_idx += cur_param_count
        return x

    def random_mask_generator(self, batch_size, repeat=1):
        """Draw a boolean on/off mask of shape (batch_size * repeat, 1, 1) per effect.

        An item is switched on for an effect with probability
        ``fx_prob[effect]``. With ``repeat > 1`` the drawn masks are tiled
        along the batch axis. Masks are placed on ``self.device``.
        """
        mask = {}
        for cur_fx in self.tgt_fx_names:
            cur_mask = self.fx_prob[cur_fx] > torch.rand(batch_size).view(-1, 1, 1)
            if repeat > 1:
                cur_mask = cur_mask.repeat(repeat, 1, 1)
            mask[cur_fx] = cur_mask.to(self.device)
        return mask