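# NOTE: the text "Spaces: Running Running" was captured along with this file; it is page-status residue, not part of the program.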
import torch | |
import soundfile as sf | |
import numpy as np | |
import argparse | |
import os | |
import yaml | |
import sys | |
currentdir = os.path.dirname(os.path.realpath(__file__)) | |
sys.path.append(os.path.dirname(currentdir)) | |
from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder | |
from modules.loss import AudioFeatureLoss, Loss | |
class MasteringStyleTransfer:
    """Mastering style transfer driven by a reference recording.

    Holds two networks restored from checkpoints: an ``Effects_Encoder`` that
    turns a reference tensor into a feature, and a
    ``Dasp_Mastering_Style_Transfer`` converter that processes an input tensor
    conditioned on such a feature. Optionally refines the feature with
    inference-time optimization (ITO).
    """

    def __init__(self, args):
        """Select the compute device and load both networks.

        Args:
            args: Namespace-like object. The loaders read ``cfg_enc``,
                ``encoder_path``, ``cfg_converter``, ``model_path`` and
                ``sample_rate`` from it.
        """
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load models
        self.effects_encoder = self.load_effects_encoder()
        self.mastering_converter = self.load_mastering_converter()

    def load_effects_encoder(self):
        """Build the effects encoder, restore its weights and set eval mode."""
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
        reload_weights(effects_encoder, self.args.encoder_path, self.device)
        effects_encoder.to(self.device)
        effects_encoder.eval()
        return effects_encoder

    def load_mastering_converter(self):
        """Build the mastering converter, restore its weights and set eval mode."""
        mastering_converter = Dasp_Mastering_Style_Transfer(
            num_features=2048,
            sample_rate=self.args.sample_rate,
            tgt_fx_names=['eq', 'distortion', 'multiband_comp', 'gain', 'imager', 'limiter'],
            model_type='tcn',
            config=self.args.cfg_converter,
            batch_size=1,
        )
        reload_weights(mastering_converter, self.args.model_path, self.device)
        mastering_converter.to(self.device)
        mastering_converter.eval()
        return mastering_converter

    def get_reference_embedding(self, reference_tensor):
        """Run the effects encoder on ``reference_tensor`` without tracking gradients."""
        with torch.no_grad():
            reference_feature = self.effects_encoder(reference_tensor)
        return reference_feature

    def mastering_style_transfer(self, input_tensor, reference_feature):
        """Run the converter once, without gradients.

        Returns:
            Tuple ``(output_audio, predicted_params)`` where ``predicted_params``
            is whatever the converter reports via ``get_last_predicted_params()``.
        """
        with torch.no_grad():
            output_audio = self.mastering_converter(input_tensor, reference_feature)
            predicted_params = self.mastering_converter.get_last_predicted_params()
        return output_audio, predicted_params

    @staticmethod
    def _snapshot_params(params):
        """Return a detached copy of a (possibly nested) dict of tensors.

        Non-tensor, non-dict values are returned unchanged.
        """
        if isinstance(params, dict):
            return {
                key: MasteringStyleTransfer._snapshot_params(value)
                for key, value in params.items()
            }
        if isinstance(params, torch.Tensor):
            return params.detach().clone()
        return params

    def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
        """Refine the reference feature by gradient descent on an audio-feature loss.

        Args:
            input_tensor: Audio tensor passed to the converter at every step.
            reference_tensor: Audio tensor the converter output is compared to.
            ito_config: Mapping with keys ``'optimizer'`` (name of a class in
                ``torch.optim``), ``'learning_rate'``, ``'num_steps'``,
                ``'af_weights'`` and ``'sample_rate'``.
            initial_reference_feature: Starting point of the optimization. It
                is copied and never modified.

        Returns:
            Tuple ``(best_output, best_params, best_embedding, best_step)``
            for the step with the lowest loss; ``best_step`` is 1-based. The
            first three are ``None`` when no step produced a finite
            improvement (e.g. ``num_steps == 0``).
        """
        # nn.Parameter(t) shares storage with t, so optimizing it directly
        # would overwrite the caller's embedding in place. Work on a copy.
        fit_embedding = torch.nn.Parameter(initial_reference_feature.detach().clone())
        optimizer = getattr(torch.optim, ito_config['optimizer'])(
            [fit_embedding], lr=ito_config['learning_rate']
        )
        af_loss = AudioFeatureLoss(
            weights=ito_config['af_weights'],
            sample_rate=ito_config['sample_rate'],
            stem_separation=False,
            use_clap=False
        )

        min_loss = float('inf')
        min_loss_step = 0
        min_loss_output = None
        min_loss_params = None
        min_loss_embedding = None
        loss_history = []
        divergence_counter = 0

        for step in range(ito_config['num_steps']):
            optimizer.zero_grad()
            output_audio = self.mastering_converter(input_tensor, fit_embedding)
            losses = af_loss(output_audio, reference_tensor)
            total_loss = sum(losses.values())
            current_loss = total_loss.item()
            loss_history.append(current_loss)

            if current_loss < min_loss:
                min_loss = current_loss
                min_loss_step = step
                min_loss_output = output_audio.detach()
                # This forward pass ran with gradients enabled; store detached
                # copies so the kept parameters do not hold on to the graph,
                # matching how the audio is stored above.
                min_loss_params = self._snapshot_params(
                    self.mastering_converter.get_last_predicted_params()
                )
                min_loss_embedding = fit_embedding.detach().clone()

            # Check for divergence: the loss is worse than it was 10 steps ago.
            if len(loss_history) > 10 and current_loss > loss_history[-11]:
                divergence_counter += 1
            else:
                divergence_counter = 0

            print(total_loss, min_loss)

            if divergence_counter >= 10:
                print(f"Optimization stopped early due to divergence at step {step}")
                break

            total_loss.backward()
            optimizer.step()

        return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1

    def _load_stereo_tensor(self, path):
        """Read an audio file into a batched tensor on ``self.device``.

        One-dimensional (mono) data is duplicated into two channels; otherwise
        the array returned by ``sf.read`` is transposed. A leading batch
        dimension is added.

        Returns:
            Tuple ``(tensor, sample_rate)``.
        """
        audio, sample_rate = sf.read(path)
        audio = np.stack([audio, audio]) if audio.ndim == 1 else audio.transpose(1, 0)
        return torch.FloatTensor(audio).unsqueeze(0).to(self.device), sample_rate

    def process_audio(self, input_path, reference_path, ito_reference_path, ito_config, perform_ito):
        """Master ``input_path`` in the style of ``reference_path``.

        Args:
            input_path: Audio file to process; its sample rate is returned.
            reference_path: Audio file that provides the reference feature.
            ito_reference_path: Audio file used as the ITO target. Only read
                when ``perform_ito`` is true.
            ito_config: Configuration forwarded to
                ``inference_time_optimization``.
            perform_ito: Whether to run inference-time optimization.

        Returns:
            Tuple ``(output_audio, predicted_params, ito_output_audio,
            ito_predicted_params, optimized_reference_feature, sr, ito_steps)``.
            Audio entries are NumPy arrays. When ITO is skipped, all
            ITO-related entries are ``None``.
        """
        # NOTE(review): the sample rates of the reference files are ignored and
        # no resampling is done here -- confirm callers supply matching rates.
        input_tensor, sr = self._load_stereo_tensor(input_path)
        reference_tensor, _ = self._load_stereo_tensor(reference_path)

        reference_feature = self.get_reference_embedding(reference_tensor)
        output_audio, predicted_params = self.mastering_style_transfer(input_tensor, reference_feature)

        ito_output_audio, ito_predicted_params, optimized_reference_feature, ito_steps = None, None, None, None
        if perform_ito:
            # Load the ITO target lazily so a run without ITO does not need it.
            ito_reference_tensor, _ = self._load_stereo_tensor(ito_reference_path)
            ito_output_audio, ito_predicted_params, optimized_reference_feature, ito_steps = self.inference_time_optimization(
                input_tensor, ito_reference_tensor, ito_config, reference_feature
            )
            # ITO yields None when no step was recorded as best
            # (zero steps or non-finite losses); nothing to convert or compare.
            if ito_output_audio is not None:
                ito_output_audio = ito_output_audio.squeeze().cpu().numpy()
                print("\nDifference between initial and ITO predicted parameters:")
                self.print_param_difference(predicted_params, ito_predicted_params)

        output_audio = output_audio.squeeze().cpu().numpy()
        return output_audio, predicted_params, ito_output_audio, ito_predicted_params, optimized_reference_feature, sr, ito_steps

    @staticmethod
    def _scalar(value):
        """Unwrap tensors / NumPy scalars via ``.item()``; pass other values through."""
        return value.item() if hasattr(value, "item") else value

    def _fx_param_diffs(self, fx_name, initial_params, ito_params):
        """Yield ``(fx_name, param_name, initial, ito, normalized_diff)`` for one effect.

        Dict-valued effects are normalized by the width of the matching entry
        in the converter's ``fx_processors[fx_name].param_ranges``. Any other
        effect is reported as a single ``'width'`` parameter with the raw
        absolute difference. All numeric fields are plain Python scalars.
        """
        initial_fx = initial_params[fx_name]
        if isinstance(initial_fx, dict):
            for param_name, initial_value in initial_fx.items():
                ito_value = ito_params[fx_name][param_name]
                # Calculate normalized difference
                param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
                normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))
                yield (
                    fx_name,
                    param_name,
                    self._scalar(initial_value),
                    self._scalar(ito_value),
                    self._scalar(normalized_diff),
                )
        else:
            ito_value = ito_params[fx_name]
            # For 'imager', assume range is 0 to 1
            normalized_diff = abs(ito_value - initial_fx)
            yield (
                fx_name,
                'width',
                self._scalar(initial_fx),
                self._scalar(ito_value),
                self._scalar(normalized_diff),
            )

    @staticmethod
    def _top_diff_lines(all_diffs):
        """Format the 10 largest normalized differences as a list of text lines."""
        top_diffs = sorted(all_diffs, key=lambda entry: entry[4], reverse=True)[:10]
        lines = ["Top 10 parameter differences (sorted by normalized difference):"]
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            lines.append(f"{fx_name.upper()} - {param_name}:")
            lines.append(f" Initial: {initial_value:.4f}")
            lines.append(f" ITO: {ito_value:.4f}")
            lines.append(f" Normalized Diff: {normalized_diff:.4f}")
            lines.append("")
        return lines

    def print_param_difference(self, initial_params, ito_params):
        """Print every parameter difference, then the 10 largest ones.

        Args:
            initial_params: Parameters from the initial prediction, keyed by
                effect name.
            ito_params: Parameters after ITO, with the same structure.
        """
        all_diffs = []
        print("\nAll parameter differences:")
        for fx_name in initial_params:
            print(f"\n{fx_name.upper()}:")
            for entry in self._fx_param_diffs(fx_name, initial_params, ito_params):
                _, param_name, initial_value, ito_value, normalized_diff = entry
                all_diffs.append(entry)
                print(f" {param_name}:")
                print(f" Initial: {initial_value:.4f}")
                print(f" ITO: {ito_value:.4f}")
                print(f" Normalized Diff: {normalized_diff:.4f}")

        # Sort differences by normalized difference and get top 10.
        # The trailing empty entry plus print's newline reproduces the blank
        # line that follows each reported parameter.
        print("\n" + "\n".join(self._top_diff_lines(all_diffs)))

    def print_predicted_params(self, predicted_params):
        """Print predicted parameters grouped by effect name.

        Tensors are converted to NumPy arrays before printing; ``None`` prints
        a placeholder message.
        """
        if predicted_params is None:
            print("No predicted parameters available.")
            return

        print("Predicted Parameters:")
        for fx_name, fx_params in predicted_params.items():
            print(f"\n{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.detach().cpu().numpy()
                    print(f" {param_name}: {param_value}")
            elif isinstance(fx_params, torch.Tensor):
                param_value = fx_params.detach().cpu().numpy()
                print(f" {param_value}")
            else:
                print(f" {fx_params}")

    def get_param_output_string(self, params):
        """Return the parameters as text, one value per line with 4 decimals.

        Returns ``"No parameters available"`` when ``params`` is ``None``.
        """
        if params is None:
            return "No parameters available"

        output = []
        for fx_name, fx_params in params.items():
            output.append(f"{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    output.append(f" {param_name}: {self._scalar(param_value):.4f}")
            else:
                output.append(f" {self._scalar(fx_params):.4f}")
        return "\n".join(output)

    def get_top_10_diff_string(self, initial_params, ito_params):
        """Return the 10 largest normalized parameter differences as text.

        Values may be tensors or plain numbers. Returns
        ``"Cannot compare parameters"`` when either argument is ``None``.
        """
        if initial_params is None or ito_params is None:
            return "Cannot compare parameters"

        all_diffs = [
            entry
            for fx_name in initial_params
            for entry in self._fx_param_diffs(fx_name, initial_params, ito_params)
        ]
        return "\n".join(self._top_diff_lines(all_diffs))
def reload_weights(model, ckpt_path, device):
    """Load the weights stored under ``checkpoint["model"]`` into ``model``.

    A leading ``module.`` on a parameter name is removed before loading; names
    without that prefix are used as they are.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Path of a checkpoint file readable by ``torch.load``.
        device: ``map_location`` used while loading the checkpoint.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
    """
    from collections import OrderedDict

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)

    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in checkpoint["model"].items():
        # Strip the prefix only when it is present. Blindly dropping the first
        # seven characters mangles un-prefixed names, and strict=False below
        # would then silently leave the model at its initial weights.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value

    # strict=False: keys that do not match the model are ignored, not raised.
    model.load_state_dict(new_state_dict, strict=False)
if __name__ == "__main__":
    # Command-line entry point: parse options, load the network configuration,
    # run the style transfer (optionally with ITO) and write the results as
    # WAV files in the current working directory.
    basis_path = '/data2/tony/Mastering_Style_Transfer/results/dasp_tcn_tuneenc_daspman_loudnessnorm/ckpt/1000/'

    # (flag, add_argument keyword options), registered in this order.
    cli_options = [
        ("--input_path", dict(type=str, required=True, help="Path to input audio file")),
        ("--reference_path", dict(type=str, required=True, help="Path to reference audio file")),
        ("--ito_reference_path", dict(type=str, required=True, help="Path to ITO reference audio file")),
        ("--model_path", dict(type=str, default=f"{basis_path}dasp_tcn_tuneenc_daspman_loudnessnorm_mastering_converter_1000.pt", help="Path to mastering converter model")),
        ("--encoder_path", dict(type=str, default=f"{basis_path}dasp_tcn_tuneenc_daspman_loudnessnorm_effects_encoder_1000.pt", help="Path to effects encoder model")),
        ("--perform_ito", dict(action="store_true", help="Whether to perform ITO")),
        ("--optimizer", dict(type=str, default="RAdam", help="Optimizer for ITO")),
        ("--learning_rate", dict(type=float, default=0.001, help="Learning rate for ITO")),
        ("--num_steps", dict(type=int, default=100, help="Number of optimization steps for ITO")),
        ("--af_weights", dict(nargs='+', type=float, default=[0.1, 0.001, 1.0, 1.0, 0.1], help="Weights for AudioFeatureLoss")),
        ("--sample_rate", dict(type=int, default=44100, help="Sample rate for AudioFeatureLoss")),
        ("--path_to_config", dict(type=str, default='/home/tony/mastering_transfer/networks/configs.yaml', help="Path to network architecture configuration file")),
    ]
    parser = argparse.ArgumentParser(description="Mastering Style Transfer")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Load network configurations and attach the two sections the model
    # loaders read from ``args``.
    # NOTE(review): full_load uses PyYAML's FullLoader; safe_load would be
    # enough if the config holds only plain YAML types -- confirm.
    with open(args.path_to_config, 'r') as config_file:
        network_configs = yaml.full_load(config_file)
    args.cfg_converter = network_configs['TCN']['param_mapping']
    args.cfg_enc = network_configs['Effects_Encoder']['default']

    # The ITO settings are taken from the matching command-line options.
    ito_config = {
        key: getattr(args, key)
        for key in ('optimizer', 'learning_rate', 'num_steps', 'af_weights', 'sample_rate')
    }

    engine = MasteringStyleTransfer(args)
    results = engine.process_audio(
        args.input_path, args.reference_path, args.ito_reference_path, ito_config, args.perform_ito
    )
    (output_audio, predicted_params, ito_output_audio, ito_predicted_params,
     optimized_reference_feature, sr, ito_steps) = results

    # Save the output audio; arrays are transposed before writing.
    sf.write("output_mastered.wav", output_audio.T, sr)
    if ito_output_audio is not None:
        sf.write("ito_output_mastered.wav", ito_output_audio.T, sr)