update loss
- app.py +7 -10
- inference.py +2 -2
- modules/loss.py +19 -32
app.py
CHANGED
@@ -94,8 +94,7 @@ def process_audio(input_audio, reference_audio):
 
     return (sr, output_audio), param_output, (sr, normalized_input)
 
-
-def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt):
+def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn):
     if ito_reference_audio is None:
         ito_reference_audio = reference_audio
     af_weights = [float(w.strip()) for w in af_weights.split(',')]
@@ -108,7 +107,8 @@ def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, op
         'sample_rate': args.sample_rate,
         'loss_function': loss_function,
         'clap_target_type': clap_target_type,
-        'clap_text_prompt': clap_text_prompt
+        'clap_text_prompt': clap_text_prompt,
+        'clap_distance_fn': clap_distance_fn
     }
 
     input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
@@ -163,7 +163,7 @@ with gr.Blocks() as demo:
     gr.Markdown("Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. \
                 The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. \
                 Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding $z_{ref}$ to further gain control over the output mastering style.")
-    gr.Image("ito_snow.png", width=
+    gr.Image("ito_snow.png", width=100, label="ITO pipeline")
 
     gr.Markdown("## Step 1: Mastering Style Transfer")
 
@@ -219,14 +219,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         ito_reference_audio = gr.Audio(label="ITO Reference Style Audio $x'_{ref}$ (optional)")
         with gr.Column():
-            num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
+            num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps for additional optimization")
            optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
            learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
            loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
-
-            # af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
-            # clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
-            # clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
 
     # Audio Feature Loss weights
     with gr.Column(visible=True) as audio_feature_weights:
@@ -240,6 +236,7 @@ with gr.Blocks() as demo:
     with gr.Column(visible=False) as clap_options:
         clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio")
        clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
+        clap_distance_fn = gr.Dropdown(["cosine", "mse", "l1"], label="CLAP Distance Function", value="cosine")
 
    def update_clap_options(loss_function):
        if loss_function == "CLAPFeatureLoss":
@@ -285,7 +282,7 @@ with gr.Blocks() as demo:
 
    ito_button.click(
        perform_ito,
-        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt],
+        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn],
        outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
    ).then(
        update_ito_output,
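The diff only shows the first lines of `update_clap_options`, which toggles the CLAP controls; the new `clap_distance_fn` dropdown sits inside the `clap_options` column, so it is shown and hidden together with the other CLAP inputs. Below is a minimal sketch of how such a toggle is typically wired in Gradio — the actual body and outputs of `update_clap_options` in app.py are not shown in this diff and may differ.

```python
# Hypothetical sketch (not the exact app.py code): swap visibility of the
# AudioFeatureLoss weight column and the CLAP options column when the
# loss-function radio changes.
def update_clap_options(loss_function):
    if loss_function == "CLAPFeatureLoss":
        return gr.update(visible=False), gr.update(visible=True)
    return gr.update(visible=True), gr.update(visible=False)

loss_function.change(
    update_clap_options,
    inputs=loss_function,
    outputs=[audio_feature_weights, clap_options],
)
```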
inference.py
CHANGED
@@ -35,7 +35,7 @@ class MasteringStyleTransfer:
                                             STEMS=['mixture'], \
                                             EFFECTS=['eq', 'imager', 'loudness'])
        # Loss functions
-        self.clap_loss = CLAPFeatureLoss(
+        self.clap_loss = CLAPFeatureLoss()
 
    def load_effects_encoder(self):
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
@@ -97,7 +97,7 @@ class MasteringStyleTransfer:
                target = reference_tensor
            else:
                target = ito_config['clap_text_prompt']
-            losses = self.clap_loss(output_audio, target, self.args.sample_rate)
+            losses = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
            total_loss = losses
 
            if total_loss < min_loss:
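With this change, the CLAP loss object is constructed once without a fixed distance function, and the distance is selected per ITO step from `ito_config`. A condensed sketch of the loop body implied by the diff follows; the branch on `clap_target_type` and everything outside the shown lines are assumptions, since only the loss call itself appears above.

```python
# Condensed sketch of one ITO step with the new per-call distance function
# (branch condition and surrounding variable names are assumptions).
if ito_config['loss_function'] == 'CLAPFeatureLoss':
    if ito_config['clap_target_type'] == 'Audio':
        target = reference_tensor                    # style reference audio
    else:
        target = ito_config['clap_text_prompt']      # free-form text prompt
    losses = self.clap_loss(
        output_audio, target, self.args.sample_rate,
        distance_fn=ito_config['clap_distance_fn'],  # 'cosine', 'mse', or 'l1'
    )
    total_loss = losses
```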
modules/loss.py
CHANGED
@@ -180,23 +180,13 @@ import laion_clap
 import torchaudio
 # CLAP feature loss
 class CLAPFeatureLoss(nn.Module):
-    def __init__(self
+    def __init__(self):
        super(CLAPFeatureLoss, self).__init__()
        self.target_sample_rate = 48000 # CLAP expects 48kHz audio
        self.model = laion_clap.CLAP_Module(enable_fusion=False)
        self.model.load_ckpt() # download the default pretrained checkpoint
-
-        self.distance_fn = distance_fn
-        if distance_fn == 'mse':
-            self.compute_distance = F.mse_loss
-        elif distance_fn == 'l1':
-            self.compute_distance = F.l1_loss
-        elif distance_fn == 'cosine':
-            self.compute_distance = lambda x, y: 1 - F.cosine_similarity(x, y).mean()
-        else:
-            raise ValueError(f"Unsupported distance function: {distance_fn}")
 
-    def forward(self, input_audio, target, sample_rate):
+    def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
        # Process input audio
        input_embed = self.process_audio(input_audio, sample_rate)
 
@@ -209,7 +199,7 @@ class CLAPFeatureLoss(nn.Module):
            raise ValueError("Target must be either audio tensor or text (string or list of strings)")
 
        # Compute loss using the specified distance function
-        loss = self.compute_distance(input_embed, target_embed)
+        loss = self.compute_distance(input_embed, target_embed, distance_fn)
 
        return loss
 
@@ -230,7 +220,8 @@ class CLAPFeatureLoss(nn.Module):
        audio = self.quantize(audio)
 
        # Get CLAP embeddings
-        embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
+        with torch.no_grad():
+            embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
        return embed
 
    def process_text(self, text):
@@ -238,18 +229,29 @@ class CLAPFeatureLoss(nn.Module):
        # ensure input is a list of strings
        if not isinstance(text, list):
            text = [text]
-        embed = self.model.get_text_embedding(text, use_tensor=True)
+        with torch.no_grad():
+            embed = self.model.get_text_embedding(text, use_tensor=True)
        return embed
 
+    def compute_distance(self, x, y, distance_fn):
+        if distance_fn == 'mse':
+            return F.mse_loss(x, y)
+        elif distance_fn == 'l1':
+            return F.l1_loss(x, y)
+        elif distance_fn == 'cosine':
+            return 1 - F.cosine_similarity(x, y).mean()
+        else:
+            raise ValueError(f"Unsupported distance function: {distance_fn}")
+
    def quantize(self, audio):
        audio = audio.squeeze(1) # Remove channel dimension
        audio = torch.clamp(audio, -1.0, 1.0)
        audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
        return audio
 
-    def resample(self, audio,
+    def resample(self, audio, input_sample_rate):
        resampler = torchaudio.transforms.Resample(
-            orig_freq=
+            orig_freq=input_sample_rate, new_freq=self.target_sample_rate
        ).to(audio.device)
        return resampler(audio)
 
@@ -506,18 +508,3 @@ class AudioFeatureLoss(torch.nn.Module):
 
        return losses
 
-
-
-if __name__ == "__main__":
-    clap_loss = CLAPFeatureLoss(distance_fn='cosine')
-
-    input_audio = torch.randn(1, 2, 44100)
-    target_audio = torch.randn(1, 2, 44100)
-    target_text = "This is a test"
-    sample_rate = 44100
-    loss = clap_loss(input_audio, target_audio, sample_rate)
-    print(loss)
-    loss = clap_loss(input_audio, target_text, sample_rate)
-    print(loss)
-
-    print(loss.item())
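The commit also drops the `__main__` smoke test at the bottom of modules/loss.py. Below is a minimal usage sketch of the refactored interface, reusing the shapes and values from the removed test; whether stereo input is downmixed inside `process_audio` is not visible in this diff, and the laion_clap checkpoint download is assumed to succeed.

```python
import torch

# distance_fn is no longer a constructor argument; it is chosen per forward call
clap_loss = CLAPFeatureLoss()

input_audio = torch.randn(1, 2, 44100)   # (batch, channels, samples)
target_audio = torch.randn(1, 2, 44100)
target_text = "This is a test"
sample_rate = 44100

loss_audio = clap_loss(input_audio, target_audio, sample_rate, distance_fn='cosine')
loss_text = clap_loss(input_audio, target_text, sample_rate, distance_fn='l1')
print(loss_audio.item(), loss_text.item())
```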