ITO-Master / networks /architectures.py
jhtonyKoo's picture
revise
6fc042a
raw
history blame
17.7 kB
"""
Implementation of neural networks used in the task 'Music Mastering Style Transfer'
- 'Effects Encoder'
- 'Mastering Style Transfer'
- 'Differentiable Mastering Style Transfer'
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import dasp_pytorch
import os
import sys
import time
currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(currentdir)
from network_utils import *
from dasp_additionals import Multiband_Compressor, Distortion, Limiter
# compute receptive field
def compute_receptive_field(kernels, strides, dilations):
    """Accumulate a receptive-field figure over a stack of convolution layers.

    Layers are visited in order. For layer ``i`` the running total ``rf`` is
    increased by ``rf * strides[i] + (kernels[i] - strides[i]) * dilations[i]``.

    Args:
        kernels: Per-layer kernel sizes; its length sets the number of layers.
        strides: Per-layer strides, indexed in step with ``kernels``.
        dilations: Per-layer dilations, indexed in step with ``kernels``.

    Returns:
        The accumulated value (``0`` when ``kernels`` is empty).
    """
    total = 0
    for idx, kernel in enumerate(kernels):
        stride = strides[idx]
        # The increment is computed from the total *before* this layer is added.
        increment = total * stride + (kernel - stride) * dilations[idx]
        total = total + increment
    return total
# Encoder of music effects for contrastive learning of music effects
class Effects_Encoder(nn.Module):
    """Convolutional encoder that turns stereo audio into one feature vector.

    A stack of 1-D convolution blocks (``Res_ConvBlock`` or ``ConvBlock``) is
    built from ``config`` and followed by global average pooling over the last
    axis.

    Args:
        config: Mapping with the keys
            ``"channels"`` (output channels of each layer; the stereo input
            width of 2 is prepended internally),
            ``"kernels"``, ``"strides"``, ``"dilation"`` (per-layer values),
            ``"conv_block"`` (``'res'`` or ``'conv'``),
            ``"norm"`` and ``"activation"`` (forwarded to the blocks).

    Raises:
        ValueError: If ``config["conv_block"]`` is neither ``'res'`` nor
            ``'conv'`` (previously this silently produced an empty encoder).
    """

    def __init__(self, config):
        super(Effects_Encoder, self).__init__()

        # Input is stereo audio, so the first layer consumes 2 channels.
        # Work on a copy: inserting into config["channels"] in place would
        # prepend another 2 every time the same config builds an encoder.
        channels = [2] + list(config["channels"])

        # encoder layers
        encoder = []
        for i in range(len(config["kernels"])):
            if config["conv_block"] == 'res':
                encoder.append(Res_ConvBlock(dimension=1,
                                             in_channels=channels[i],
                                             out_channels=channels[i + 1],
                                             kernel_size=config["kernels"][i],
                                             stride=config["strides"][i],
                                             padding="SAME",
                                             dilation=config["dilation"][i],
                                             norm=config["norm"],
                                             activation=config["activation"],
                                             last_activation=config["activation"]))
            elif config["conv_block"] == 'conv':
                encoder.append(ConvBlock(dimension=1,
                                         layer_num=1,
                                         in_channels=channels[i],
                                         out_channels=channels[i + 1],
                                         kernel_size=config["kernels"][i],
                                         stride=config["strides"][i],
                                         padding="VALID",
                                         dilation=config["dilation"][i],
                                         norm=config["norm"],
                                         activation=config["activation"],
                                         last_activation=config["activation"],
                                         mode='conv'))
            else:
                raise ValueError(
                    f"unsupported conv_block ({config['conv_block']}); "
                    "expected 'res' or 'conv'")
        self.encoder = nn.Sequential(*encoder)

        # pooling method: collapse the last axis to a single averaged value
        self.glob_pool = nn.AdaptiveAvgPool1d(1)

    # network forward operation
    def forward(self, input):
        """Encode ``input`` and return the globally pooled features.

        Args:
            input: Audio tensor; assumed to be (batch, 2, samples) -- TODO
                confirm against callers.

        Returns:
            The encoder output averaged over its last axis, with that axis
            squeezed away.
        """
        enc_output = self.encoder(input)
        glob_pooled = self.glob_pool(enc_output).squeeze(-1)
        return glob_pooled
class TCNBlock(torch.nn.Module):
    """One temporal-convolution block with an optional FiLM conditioning step.

    Main path: dilated ``Conv1d`` -> ``BatchNorm1d`` -> ``LeakyReLU`` ->
    (optional) ``FiLM``. A bias-free 1x1 convolution of the block input is
    added to the result as a residual connection.

    Args:
        in_ch (int): Number of input channels.
        out_ch (int): Number of output channels.
        kernel_size (int): Width of the main convolution kernel. Default: 3
        stride (int): Stride of both the main and the residual convolution.
            Default: 1
        dilation (int): Dilation of the main convolution. Default: 1
        cond_dim (int): Size handed to ``FiLM`` as its first argument.
            Default: 2048
        grouped (bool): Use ``groups=out_ch`` in the main convolution when
            ``in_ch`` is divisible by ``out_ch``. Default: False
        causal (bool): Pad by ``(kernel_size-1)*dilation`` and trim the same
            number of trailing frames so that no future sample is used.
            Default: False
        conditional (bool): Create and apply the ``FiLM`` layer. Default: False
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 cond_dim=2048,
                 grouped=False,
                 causal=False,
                 conditional=False,
                 **kwargs):
        super(TCNBlock, self).__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.grouped = grouped
        self.causal = causal
        self.conditional = conditional

        groups = out_ch if grouped and (in_ch % out_ch == 0) else 1

        # Conv1d pads both ends by this amount. The causal variant pads the
        # full kernel span and later trims the tail; otherwise half the span
        # is used on each side.
        span = (kernel_size - 1) * dilation
        self.pad_length = span if self.causal else span // 2

        self.conv1 = torch.nn.Conv1d(in_ch,
                                     out_ch,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=self.pad_length,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=False)
        if grouped:
            # NOTE: forward() does not apply this layer; it is kept so that
            # state_dict keys of existing checkpoints still match.
            self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)

        if conditional:
            self.film = FiLM(cond_dim, out_ch)
        self.bn = torch.nn.BatchNorm1d(out_ch)

        self.relu = torch.nn.LeakyReLU()
        # Residual projection; groups=in_ch requires out_ch % in_ch == 0.
        self.res = torch.nn.Conv1d(in_ch,
                                   out_ch,
                                   kernel_size=1,
                                   stride=stride,
                                   groups=in_ch,
                                   bias=False)

    def forward(self, x, p=None):
        """Run the block.

        Args:
            x: Input tensor for ``Conv1d`` with ``in_ch`` channels.
            p: Conditioning passed to ``FiLM``; only read when the block was
                built with ``conditional=True``.

        Returns:
            Sum of the main path and the residual projection.
        """
        x_in = x

        x = self.relu(self.bn(self.conv1(x)))
        # The FiLM layer only exists for conditional blocks.
        if self.conditional:
            x = self.film(x, p)

        x_res = self.res(x_in)

        # Trim the frames produced by the right-hand padding. Skipped when
        # pad_length is 0 because x[..., :-0] would be an empty tensor.
        if self.causal and self.pad_length > 0:
            x = x[..., :-self.pad_length]

        return x + x_res
import pytorch_lightning as pl
class TCNModel(pl.LightningModule):
    """ Temporal convolutional network with conditioning module.

        Args:
            nparams (int): Number of conditioning parameters. Blocks are built with conditioning only when this is > 0.
            ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
            noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
            nblocks (int): Number of total TCN blocks (must be >= 1). Default: 10
            kernel_size (int): Width of the convolutional kernels. Default: 3
            stride (int or list): Stride of every block, or one stride per block when a list is given. Default: 1
            dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
            channel_growth (int): When > 1, compute the output channels at each block as in_ch * channel_growth. Default: 1
            channel_width (int): When channel_growth <= 1 all blocks use convolutions with this many channels. Default: 32
            stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
            cond_dim (int): Conditioning dimension forwarded to every block. Default: 2048
            grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
            causal (bool): Causal TCN configuration does not consider future input values. Default: False
            skip_connections (bool): Stored as a hyperparameter only; forward() does not use it. Default: False
            num_examples (int): Stored as a hyperparameter only. Default: 4
            save_dir (str): Stored as a hyperparameter only. Default: None

        Raises:
            ValueError: If nblocks is smaller than 1.
        """

    def __init__(self,
                 nparams,
                 ninputs=1,
                 noutputs=1,
                 nblocks=10,
                 kernel_size=3,
                 stride=1,
                 dilation_growth=1,
                 channel_growth=1,
                 channel_width=32,
                 stack_size=10,
                 cond_dim=2048,
                 grouped=False,
                 causal=False,
                 skip_connections=False,
                 num_examples=4,
                 save_dir=None,
                 **kwargs):
        super(TCNModel, self).__init__()
        self.save_hyperparameters()

        # Without at least one block there is no channel count for the output
        # layer (this used to surface as an UnboundLocalError on out_ch).
        if nblocks < 1:
            raise ValueError(f"nblocks must be >= 1, got {nblocks}")

        self.blocks = torch.nn.ModuleList()
        for n in range(nblocks):
            # the first block reads the raw input, later ones the previous output
            in_ch = out_ch if n > 0 else ninputs

            if self.hparams.channel_growth > 1:
                out_ch = in_ch * self.hparams.channel_growth
            else:
                out_ch = self.hparams.channel_width

            # dilation restarts at 1 at the beginning of every stack
            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
            cur_stride = stride[n] if isinstance(stride, list) else stride
            self.blocks.append(TCNBlock(in_ch,
                                        out_ch,
                                        kernel_size=self.hparams.kernel_size,
                                        stride=cur_stride,
                                        dilation=dilation,
                                        causal=self.hparams.causal,
                                        cond_dim=cond_dim,
                                        grouped=self.hparams.grouped,
                                        conditional=self.hparams.nparams > 0))

        self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)

    def forward(self, x, cond):
        """Run all blocks and project to ``noutputs`` channels.

        Args:
            x: Input tensor with ``ninputs`` channels.
            cond: Conditioning shared by every block, or a list holding one
                conditioning entry per block (indexed by block position).

        Returns:
            Output of the final 1x1 convolution, hard-clamped to [-1, 1].
        """
        per_block_cond = isinstance(cond, list)

        # iterate over blocks passing conditioning
        for idx, block in enumerate(self.blocks):
            x = block(x, cond[idx] if per_block_cond else cond)

        # a hard clamp (rather than tanh) bounds the output
        return torch.clamp(self.output(x), min=-1, max=1)

    def compute_receptive_field(self):
        """ Compute the receptive field in samples.

        Only kernel size and dilation are taken into account; strides are
        not part of this calculation.
        """
        rf = self.hparams.kernel_size
        for n in range(1, self.hparams.nblocks):
            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
            rf = rf + ((self.hparams.kernel_size - 1) * dilation)
        return rf

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options."""
        # Imported here because the module does not visibly import it at the top.
        from argparse import ArgumentParser

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- model related ---
        parser.add_argument('--ninputs', type=int, default=1)
        parser.add_argument('--noutputs', type=int, default=1)
        parser.add_argument('--nblocks', type=int, default=4)
        parser.add_argument('--kernel_size', type=int, default=5)
        parser.add_argument('--dilation_growth', type=int, default=10)
        parser.add_argument('--channel_growth', type=int, default=1)
        parser.add_argument('--channel_width', type=int, default=32)
        parser.add_argument('--stack_size', type=int, default=10)
        parser.add_argument('--grouped', default=False, action='store_true')
        parser.add_argument('--causal', default=False, action="store_true")
        parser.add_argument('--skip_connections', default=False, action="store_true")

        return parser
# Module for fitting SeFa parameters
class Dasp_Mastering_Style_Transfer(nn.Module):
    """Chain of audio effects driven by normalized parameters.

    The parameters either come from a conditioned ``TCNModel``
    (``model_type='tcn'``) or are a directly optimised ``nn.Parameter``
    (``model_type='ito'``). They are split, in the order of
    ``tgt_fx_names``, between the effect processors, which are applied to the
    input one after another.

    Args:
        num_features: Not used by this class; kept for interface compatibility.
        sample_rate: Sample rate handed to every effect processor.
        tgt_fx_names: Ordered effect names. Supported: ``'eq'``,
            ``'distortion'``, ``'comp'``, ``'multiband_comp'``, ``'gain'``,
            ``'imager'``, ``'limiter'``.
        model_type: ``'tcn'`` or ``'ito'`` (case-insensitive). Any other value
            builds no estimator and makes ``forward`` raise ``ValueError``.
        config: Settings for the TCN; read only when ``model_type`` is
            ``'tcn'``.
        batch_size: Number of rows of the ``'ito'`` parameter tensor.

    Raises:
        AssertionError: If ``tgt_fx_names`` contains an unknown effect name.
    """

    def __init__(self, num_features, sample_rate,
                 tgt_fx_names=['eq', 'comp', 'imager', 'gain'],
                 model_type='2mlp',
                 config=None,
                 batch_size=4):
        super(Dasp_Mastering_Style_Transfer, self).__init__()
        self.sample_rate = sample_rate
        # Copy so that neither the shared default list nor the caller's list
        # is aliased by this instance.
        self.tgt_fx_names = list(tgt_fx_names)
        self.last_predicted_params = None

        self.fx_processors = {}
        self.reset_fx_chain()

        # Count one slot group per chain entry, i.e. exactly what forward()
        # consumes ('imager' has no processor object and takes one value).
        total_num_param = sum(self._num_params_of(cur_fx) for cur_fx in self.tgt_fx_names)

        ''' model architecture '''
        self.model_type = model_type
        if self.model_type.lower() == 'tcn':
            self.network = TCNModel(nparams=config["condition_dimension"], ninputs=2,
                                    noutputs=total_num_param,
                                    nblocks=config["nblocks"],
                                    dilation_growth=config["dilation_growth"],
                                    kernel_size=config["kernel_size"],
                                    stride=config['stride'],
                                    channel_width=config["channel_width"],
                                    stack_size=config["stack_size"],
                                    cond_dim=config["condition_dimension"],
                                    causal=config["causal"])
        elif self.model_type.lower() == 'ito':
            # every normalized parameter starts at the middle of its range
            self.params = torch.nn.Parameter(torch.ones((batch_size, total_num_param)) * 0.5)

    def _make_fx_processor(self, fx_name):
        """Build the processor for ``fx_name``; ``None`` for 'imager', which has no processor object."""
        if fx_name == 'eq':
            return dasp_pytorch.ParametricEQ(sample_rate=self.sample_rate,
                                             min_gain_db=-20.0,
                                             max_gain_db=20.0,
                                             min_q_factor=0.1,
                                             max_q_factor=5.0)
        if fx_name == 'distortion':
            return Distortion(sample_rate=self.sample_rate,
                              min_gain_db=0.0,
                              max_gain_db=8.0)
        if fx_name == 'comp':
            return dasp_pytorch.Compressor(sample_rate=self.sample_rate)
        if fx_name == 'multiband_comp':
            return Multiband_Compressor(sample_rate=self.sample_rate)
        if fx_name == 'gain':
            return dasp_pytorch.Gain(sample_rate=self.sample_rate)
        if fx_name == 'imager':
            return None
        if fx_name == 'limiter':
            return Limiter(sample_rate=self.sample_rate)
        raise AssertionError(f"current fx name ({fx_name}) not found")

    def _num_params_of(self, fx_name):
        """Number of normalized values that ``fx_name`` consumes."""
        if fx_name == 'imager':
            return 1
        return self.fx_processors[fx_name].num_params

    # network forward operation
    def forward(self, x, embedding):
        """Estimate the effect parameters and apply the effect chain to ``x``.

        Args:
            x: Audio tensor; assumed to be stereo (batch, 2, samples) -- TODO
                confirm against callers.
            embedding: Conditioning for the TCN; ignored in ``'ito'`` mode.

        Returns:
            ``x`` after every effect in ``tgt_fx_names`` has been applied.

        Raises:
            ValueError: If ``model_type`` is neither ``'tcn'`` nor ``'ito'``.
        """
        model_type = self.model_type.lower()
        if model_type == 'tcn':
            # average the per-frame predictions, then squash into (0, 1)
            est_param = self.network(x, embedding).mean(dim=-1)
            est_param = torch.sigmoid(est_param)
        elif model_type == 'ito':
            # directly optimised values are only clamped to the valid range
            est_param = torch.clamp(self.params, min=0.0, max=1.0)
        else:
            raise ValueError(
                f"unsupported model_type ({self.model_type}); expected 'tcn' or 'ito'")

        self.last_predicted_params = est_param

        # dafx chain
        cur_param_idx = 0
        for cur_fx in self.tgt_fx_names:
            cur_param_count = self._num_params_of(cur_fx)
            cur_input_param = est_param[:, cur_param_idx:cur_param_idx + cur_param_count]
            if cur_fx == 'imager':
                x = dasp_pytorch.functional.stereo_widener(x,
                                                           sample_rate=self.sample_rate,
                                                           width=cur_input_param)
            else:
                x = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
            # update param index
            cur_param_idx += cur_param_count

        return x

    def reset_fx_chain(self, ):
        """Recreate every effect processor listed in ``tgt_fx_names``.

        Raises:
            AssertionError: If an effect name is not supported.
        """
        rebuilt = {}
        for cur_fx in self.tgt_fx_names:
            cur_fx_module = self._make_fx_processor(cur_fx)
            if cur_fx_module is not None:
                rebuilt[cur_fx] = cur_fx_module
        self.fx_processors = rebuilt

    def get_last_predicted_params(self):
        """Return the parameters of the latest ``forward`` call in effect units.

        Returns:
            ``None`` if ``forward`` has not run yet. Otherwise a dict keyed by
            effect name: the value for ``'imager'`` is its single parameter,
            for any other effect a dict mapping parameter name to value.
        """
        if self.last_predicted_params is None:
            return None

        params_dict = {}
        cur_param_idx = 0

        for cur_fx in self.tgt_fx_names:
            cur_param_count = self._num_params_of(cur_fx)
            normalized_params = self.last_predicted_params[:, cur_param_idx:cur_param_idx + cur_param_count]
            if cur_fx == 'imager':
                # the imager value is mapped onto the range [0, 1]
                params_dict[cur_fx] = self.denormalize_param(normalized_params, 0, 1)
            else:
                params_dict[cur_fx] = self.denormalize_params(cur_fx, normalized_params)
            cur_param_idx += cur_param_count

        return params_dict

    def denormalize_params(self, fx_name, normalized_params):
        """Map the normalized values of one effect onto its parameter ranges.

        Column ``i`` of ``normalized_params`` is paired with the ``i``-th entry
        of the processor's ``param_ranges``.

        Returns:
            Dict mapping parameter name to its denormalized value.
        """
        fx_processor = self.fx_processors[fx_name]
        original_params = {}

        for i, (param_name, (min_val, max_val)) in enumerate(fx_processor.param_ranges.items()):
            original_param = self.denormalize_param(normalized_params[:, i:i + 1], min_val, max_val)
            original_params[param_name] = original_param

        return original_params

    @staticmethod
    def denormalize_param(normalized_param, min_val, max_val):
        """Linearly map a value from [0, 1] onto [``min_val``, ``max_val``]."""
        return normalized_param * (max_val - min_val) + min_val