# ITO-Master/networks/dasp_additionals.py
"""
Implementation of differentiable mastering effects based on the DASP-pytorch and torchcomp libraries:
- Distortion
- Multiband Compressor
- Limiter
DASP-pytorch: https://github.com/csteinmetz1/dasp-pytorch
torchcomp: https://github.com/yoyololicon/torchcomp
"""
import dasp_pytorch
from dasp_pytorch.modules import Processor
import torchcomp
import torch
import torch.nn.functional as F
import torch.nn as nn
EPS = 1e-6
class Distortion(Processor):
def __init__(
self,
sample_rate: int,
min_gain_db: float = 0.0,
max_gain_db: float = 24.0,
):
super().__init__()
self.sample_rate = sample_rate
self.process_fn = distortion
self.param_ranges = {
"drive_db": (min_gain_db, max_gain_db),
"parallel_weight_factor": (0.2, 0.7),
}
self.num_params = len(self.param_ranges)
def distortion(x: torch.Tensor,
               sample_rate: int,
               drive_db: torch.Tensor,
               parallel_weight_factor: torch.Tensor):
"""Simple soft-clipping distortion with drive control.
    Args:
        x (torch.Tensor): Input audio tensor with shape (bs, chs, seq_len)
        sample_rate (int): Audio sample rate.
        drive_db (torch.Tensor): Drive in dB with shape (bs)
        parallel_weight_factor (torch.Tensor): Dry/wet weight for the parallel (clean) path with shape (bs)
    Returns:
        torch.Tensor: Output audio tensor with shape (bs, chs, seq_len)
    """
    bs, chs, seq_len = x.size()
    parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
    # apply drive per batch item (broadcast over channels and samples), then soft-clip
    x_dist = torch.tanh(x * (10 ** (drive_db.view(bs, 1, 1) / 20.0)))
    # parallel (dry/wet) mix
    return parallel_weight_factor * x_dist + (1 - parallel_weight_factor) * x
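
# Usage sketch (illustrative shapes; relies on dasp-pytorch's Processor.process_normalized,
# as used in Random_Augmentation_Dasp.forward below):
#
#   dist = Distortion(sample_rate=44100)
#   x = torch.randn(2, 2, 44100)            # (bs, chs, seq_len)
#   p = torch.rand(2, dist.num_params)      # normalized parameters in [0, 1]
#   y = dist.process_normalized(x, p)       # soft-clipped signal, parallel-mixed with the dry input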
class Multiband_Compressor(Processor):
def __init__(
self,
sample_rate: int,
min_threshold_db_comp: float = -60.0,
max_threshold_db_comp: float = 0.0-EPS,
min_ratio_comp: float = 1.0+EPS,
max_ratio_comp: float = 20.0,
min_attack_ms_comp: float = 5.0,
max_attack_ms_comp: float = 100.0,
min_release_ms_comp: float = 5.0,
max_release_ms_comp: float = 100.0,
min_threshold_db_exp: float = -60.0,
max_threshold_db_exp: float = 0.0-EPS,
min_ratio_exp: float = 0.0+EPS,
max_ratio_exp: float = 1.0-EPS,
min_attack_ms_exp: float = 5.0,
max_attack_ms_exp: float = 100.0,
min_release_ms_exp: float = 5.0,
max_release_ms_exp: float = 100.0,
):
super().__init__()
self.sample_rate = sample_rate
self.process_fn = multiband_compressor
self.param_ranges = {
"low_cutoff": (20, 300),
"high_cutoff": (2000, 12000),
"parallel_weight_factor": (0.2, 0.7),
"low_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
"low_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
"low_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
"low_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
"low_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
"low_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
"mid_band_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
"mid_band_comp_ratio": (min_ratio_comp, max_ratio_comp),
"mid_band_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
"mid_band_exp_ratio": (min_ratio_exp, max_ratio_exp),
"mid_band_at": (min_attack_ms_exp, max_attack_ms_exp),
"mid_band_rt": (min_release_ms_exp, max_release_ms_exp),
"high_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
"high_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
"high_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
"high_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
"high_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
"high_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
}
self.num_params = len(self.param_ranges)
def linkwitz_riley_4th_order(
x: torch.Tensor,
cutoff_freq: torch.Tensor,
sample_rate: float,
filter_type: str):
q_factor = torch.ones(cutoff_freq.shape) / torch.sqrt(torch.tensor([2.0]))
gain_db = torch.zeros(cutoff_freq.shape)
q_factor = q_factor.to(x.device)
gain_db = gain_db.to(x.device)
b, a = dasp_pytorch.signal.biquad(
gain_db,
cutoff_freq,
q_factor,
sample_rate,
filter_type
)
del gain_db
del q_factor
    # one second-order section per batch item: (b0, b1, b2, a0, a1, a2)
    sos = torch.cat((b, a), dim=-1).unsqueeze(1)
    # cascade the 2nd-order Butterworth section twice to obtain a 4th-order
    # Linkwitz-Riley response (360° total phase shift)
    x = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
    x_out = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
return x_out
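
# Crossover sketch (illustrative; cutoff tensors follow the (bs, 1, 1) layout used in
# multiband_compressor below):
#
#   x = torch.randn(1, 2, 44100)
#   fc = torch.full((1, 1, 1), 200.0)
#   low = linkwitz_riley_4th_order(x, fc, 44100, filter_type="low_pass")
#   high = linkwitz_riley_4th_order(x, fc, 44100, filter_type="high_pass")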
def multiband_compressor(
x: torch.Tensor,
sample_rate: float,
low_cutoff: torch.Tensor,
high_cutoff: torch.Tensor,
parallel_weight_factor: torch.Tensor,
low_shelf_comp_thresh: torch.Tensor,
low_shelf_comp_ratio: torch.Tensor,
low_shelf_exp_thresh: torch.Tensor,
low_shelf_exp_ratio: torch.Tensor,
low_shelf_at: torch.Tensor,
low_shelf_rt: torch.Tensor,
mid_band_comp_thresh: torch.Tensor,
mid_band_comp_ratio: torch.Tensor,
mid_band_exp_thresh: torch.Tensor,
mid_band_exp_ratio: torch.Tensor,
mid_band_at: torch.Tensor,
mid_band_rt: torch.Tensor,
high_shelf_comp_thresh: torch.Tensor,
high_shelf_comp_ratio: torch.Tensor,
high_shelf_exp_thresh: torch.Tensor,
high_shelf_exp_ratio: torch.Tensor,
high_shelf_at: torch.Tensor,
high_shelf_rt: torch.Tensor,
):
"""Multiband (Three-band) Compressor.
    The input is split into low, mid, and high bands with 4th-order Linkwitz-Riley
    crossover filters; each band is processed by a compressor/expander gain stage,
    the bands are summed back together, and the result is parallel-mixed with the dry input.
    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
        sample_rate (float): Audio sample rate.
        low_cutoff (torch.Tensor): Low/mid crossover frequency in Hz.
        high_cutoff (torch.Tensor): Mid/high crossover frequency in Hz.
        parallel_weight_factor (torch.Tensor): Dry/wet weight for the parallel (clean) path.
        low_shelf_comp_thresh (torch.Tensor): Low-band compressor threshold in dB.
        low_shelf_comp_ratio (torch.Tensor): Low-band compressor ratio.
        low_shelf_exp_thresh (torch.Tensor): Low-band expander threshold in dB.
        low_shelf_exp_ratio (torch.Tensor): Low-band expander ratio.
        low_shelf_at (torch.Tensor): Low-band attack time in ms.
        low_shelf_rt (torch.Tensor): Low-band release time in ms.
        mid_band_comp_thresh (torch.Tensor): Mid-band compressor threshold in dB.
        mid_band_comp_ratio (torch.Tensor): Mid-band compressor ratio.
        mid_band_exp_thresh (torch.Tensor): Mid-band expander threshold in dB.
        mid_band_exp_ratio (torch.Tensor): Mid-band expander ratio.
        mid_band_at (torch.Tensor): Mid-band attack time in ms.
        mid_band_rt (torch.Tensor): Mid-band release time in ms.
        high_shelf_comp_thresh (torch.Tensor): High-band compressor threshold in dB.
        high_shelf_comp_ratio (torch.Tensor): High-band compressor ratio.
        high_shelf_exp_thresh (torch.Tensor): High-band expander threshold in dB.
        high_shelf_exp_ratio (torch.Tensor): High-band expander ratio.
        high_shelf_at (torch.Tensor): High-band attack time in ms.
        high_shelf_rt (torch.Tensor): High-band release time in ms.
    Returns:
        y (torch.Tensor): Processed signal with shape (bs, chs, seq_len).
    """
bs, chs, seq_len = x.size()
low_cutoff = low_cutoff.view(-1, 1, 1)
high_cutoff = high_cutoff.view(-1, 1, 1)
parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
''' cross over filter '''
# Low-shelf band (low frequencies)
low_band = linkwitz_riley_4th_order(x, low_cutoff, sample_rate, filter_type="low_pass")
# High-shelf band (high frequencies)
high_band = linkwitz_riley_4th_order(x, high_cutoff, sample_rate, filter_type="high_pass")
# Mid-band (band-pass)
mid_band = x - low_band - high_band # Subtract low and high bands from original signal
''' compressor '''
try:
x_out_low = low_band * torchcomp.compexp_gain(low_band.sum(axis=1).abs(),
comp_thresh=low_shelf_comp_thresh, \
comp_ratio=low_shelf_comp_ratio, \
exp_thresh=low_shelf_exp_thresh, \
exp_ratio=low_shelf_exp_ratio, \
at=torchcomp.ms2coef(low_shelf_at, sample_rate), \
rt=torchcomp.ms2coef(low_shelf_rt, sample_rate)).unsqueeze(1)
    except Exception:
x_out_low = low_band
print('\t!!!failed computing low-band compression!!!')
try:
x_out_high = high_band * torchcomp.compexp_gain(high_band.sum(axis=1).abs(),
comp_thresh=high_shelf_comp_thresh, \
comp_ratio=high_shelf_comp_ratio, \
exp_thresh=high_shelf_exp_thresh, \
exp_ratio=high_shelf_exp_ratio, \
at=torchcomp.ms2coef(high_shelf_at, sample_rate), \
rt=torchcomp.ms2coef(high_shelf_rt, sample_rate)).unsqueeze(1)
    except Exception:
x_out_high = high_band
print('\t!!!failed computing high-band compression!!!')
try:
x_out_mid = mid_band * torchcomp.compexp_gain(mid_band.sum(axis=1).abs(),
comp_thresh=mid_band_comp_thresh, \
comp_ratio=mid_band_comp_ratio, \
exp_thresh=mid_band_exp_thresh, \
exp_ratio=mid_band_exp_ratio, \
at=torchcomp.ms2coef(mid_band_at, sample_rate), \
rt=torchcomp.ms2coef(mid_band_rt, sample_rate)).unsqueeze(1)
    except Exception:
x_out_mid = mid_band
print('\t!!!failed computing mid-band compression!!!')
x_out = x_out_low + x_out_high + x_out_mid
# parallel computation
x_out = parallel_weight_factor * x_out + (1-parallel_weight_factor) * x
    # ensure the original (bs, chs, seq_len) shape
    x_out = x_out.view(bs, chs, seq_len)
return x_out
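
# Usage sketch (illustrative shapes; normalized-parameter path via Processor.process_normalized):
#
#   mbc = Multiband_Compressor(sample_rate=44100)
#   x = torch.randn(2, 2, 44100)            # (bs, chs, seq_len)
#   p = torch.rand(2, mbc.num_params)       # normalized parameters in [0, 1]
#   y = mbc.process_normalized(x, p)        # three-band comp/exp, parallel-mixed output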
class Limiter(Processor):
def __init__(
self,
sample_rate: int,
min_threshold_db: float = -60.0,
max_threshold_db: float = 0.0-EPS,
min_attack_ms: float = 5.0,
max_attack_ms: float = 100.0,
min_release_ms: float = 5.0,
max_release_ms: float = 100.0,
):
super().__init__()
self.sample_rate = sample_rate
self.process_fn = limiter
self.param_ranges = {
"threshold": (min_threshold_db, max_threshold_db),
"at": (min_attack_ms, max_attack_ms),
"rt": (min_release_ms, max_release_ms),
}
self.num_params = len(self.param_ranges)
def limiter(
    x: torch.Tensor,
    sample_rate: float,
    threshold: torch.Tensor,
    at: torch.Tensor,
    rt: torch.Tensor,
):
"""Limiter.
    Based on the limiter from Chin-Yun Yu's torchcomp.
    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
        sample_rate (float): Audio sample rate.
        threshold (torch.Tensor): Limiter threshold in dB.
        at (torch.Tensor): Attack time in ms.
        rt (torch.Tensor): Release time in ms.
    Returns:
        y (torch.Tensor): Limited signal.
"""
bs, chs, seq_len = x.size()
x_out = x * torchcomp.limiter_gain(x.sum(axis=1).abs(),
threshold=threshold,
at=torchcomp.ms2coef(at, sample_rate),
rt=torchcomp.ms2coef(rt, sample_rate)).unsqueeze(1)
    # ensure the original (bs, chs, seq_len) shape
    x_out = x_out.view(bs, chs, seq_len)
return x_out
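
# Usage sketch (illustrative shapes; normalized-parameter path via Processor.process_normalized):
#
#   lim = Limiter(sample_rate=44100)
#   x = torch.randn(2, 2, 44100)            # (bs, chs, seq_len)
#   p = torch.rand(2, lim.num_params)       # normalized threshold / attack / release
#   y = lim.process_normalized(x, p)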
class Random_Augmentation_Dasp(nn.Module):
def __init__(self, sample_rate, \
tgt_fx_names = ['eq', 'comp', 'imager', 'gain']):
super(Random_Augmentation_Dasp, self).__init__()
self.sample_rate = sample_rate
self.tgt_fx_names = tgt_fx_names
        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
self.fx_prob = {'eq': 0.9, \
'distortion': 0.3, \
'comp': 0.8, \
'multiband_comp': 0.8, \
'gain': 0.85, \
'imager': 0.6, \
'limiter': 1.0}
self.fx_processors = {}
for cur_fx in tgt_fx_names:
if cur_fx=='eq':
cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=sample_rate, \
min_gain_db = -10.0, \
max_gain_db = 10.0, \
min_q_factor = 0.5, \
max_q_factor=5.0)
elif cur_fx=='distortion':
cur_fx_module = Distortion(sample_rate=sample_rate,
min_gain_db = 0.0,
max_gain_db = 4.0)
elif cur_fx=='comp':
cur_fx_module = dasp_pytorch.Compressor(sample_rate=sample_rate)
elif cur_fx=='multiband_comp':
cur_fx_module = Multiband_Compressor(sample_rate=sample_rate,
min_threshold_db_comp = -30.0,
max_threshold_db_comp = -5.0,
min_ratio_comp = 1.5,
max_ratio_comp = 6.0,
min_attack_ms_comp = 1.0,
max_attack_ms_comp = 20.0,
min_release_ms_comp = 20.0,
max_release_ms_comp = 500.0,
min_threshold_db_exp = -30.0,
max_threshold_db_exp = -5.0,
min_ratio_exp = 0.0+EPS,
max_ratio_exp = 1.0-EPS,
min_attack_ms_exp = 1.0,
max_attack_ms_exp = 20.0,
min_release_ms_exp = 20.0,
max_release_ms_exp = 500.0,
)
elif cur_fx=='gain':
cur_fx_module = dasp_pytorch.Gain(sample_rate=sample_rate,
min_gain_db = 0.0,
max_gain_db = 6.0,)
elif cur_fx=='imager':
continue
elif cur_fx=='limiter':
cur_fx_module = Limiter(sample_rate=sample_rate,
min_threshold_db = -20.0,
max_threshold_db = 0.0-EPS,
min_attack_ms = 0.1,
max_attack_ms = 5.0,
min_release_ms = 20.0,
max_release_ms = 1000.0,)
else:
raise AssertionError(f"current fx name ({cur_fx}) not found")
self.fx_processors[cur_fx] = cur_fx_module
total_num_param = sum([self.fx_processors[cur_fx].num_params for cur_fx in self.fx_processors])
if 'imager' in tgt_fx_names:
total_num_param += 1
self.total_num_param = total_num_param
# network forward operation
def forward(self, x, rand_param=None, use_mask=None):
        if rand_param is None:
            rand_param = torch.rand((x.shape[0], self.total_num_param)).to(self.device)
        else:
            assert rand_param.shape[0] == x.shape[0] and rand_param.shape[1] == self.total_num_param
        if use_mask is None:
            use_mask = self.random_mask_generator(x.shape[0])
# dafx chain
cur_param_idx = 0
for cur_fx in self.tgt_fx_names:
cur_param_count = 1 if cur_fx=='imager' else self.fx_processors[cur_fx].num_params
if cur_fx=='imager':
x_processed = dasp_pytorch.functional.stereo_widener(x, \
sample_rate=self.sample_rate, \
width=rand_param[:,cur_param_idx:cur_param_idx+1])
else:
cur_input_param = rand_param[:, cur_param_idx:cur_param_idx+cur_param_count]
x_processed = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
# process all FX but decide to use the processed output based on probability
cur_mask = use_mask[cur_fx]
x = x_processed*cur_mask + x*~cur_mask
# update param index
cur_param_idx += cur_param_count
return x
def random_mask_generator(self, batch_size, repeat=1):
mask = {}
for cur_fx in self.tgt_fx_names:
mask[cur_fx] = self.fx_prob[cur_fx] > torch.rand(batch_size).view(-1, 1, 1)
if repeat>1:
mask[cur_fx] = mask[cur_fx].repeat(repeat, 1, 1)
mask[cur_fx] = mask[cur_fx].to(self.device)
return mask
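
# Minimal smoke test (sketch): run stereo noise through a randomly parameterized chain of
# the processors defined above. Assumes dasp-pytorch and torchcomp are installed; the
# effect list and signal shape below are illustrative only.
if __name__ == "__main__":
    aug = Random_Augmentation_Dasp(sample_rate=44100,
                                   tgt_fx_names=['distortion', 'multiband_comp', 'limiter'])
    x = torch.randn(2, 2, 44100).to(aug.device)   # (bs, chs, seq_len)
    y = aug(x)                                    # random normalized params and per-FX masks
    print(x.shape, y.shape)                       # shapes should match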