jhtonyKoo committed
Commit: c2d5e4a
1 Parent(s): e298cbd

update loss

Files changed (3):
  1. app.py +7 -10
  2. inference.py +2 -2
  3. modules/loss.py +19 -32
app.py CHANGED
@@ -94,8 +94,7 @@ def process_audio(input_audio, reference_audio):
 
     return (sr, output_audio), param_output, (sr, normalized_input)
 
-# def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
-def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt):
+def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn):
     if ito_reference_audio is None:
         ito_reference_audio = reference_audio
     af_weights = [float(w.strip()) for w in af_weights.split(',')]
@@ -108,7 +107,8 @@ def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt):
         'sample_rate': args.sample_rate,
         'loss_function': loss_function,
         'clap_target_type': clap_target_type,
-        'clap_text_prompt': clap_text_prompt
+        'clap_text_prompt': clap_text_prompt,
+        'clap_distance_fn': clap_distance_fn
     }
 
     input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
@@ -163,7 +163,7 @@ with gr.Blocks() as demo:
     gr.Markdown("Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. \
                 The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. \
                 Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding $z_{ref}$ to further gain control over the output mastering style.")
-    gr.Image("ito_snow.png", width=300)
+    gr.Image("ito_snow.png", width=100, label="ITO pipeline")
 
     gr.Markdown("## Step 1: Mastering Style Transfer")
 
@@ -219,14 +219,10 @@ with gr.Blocks() as demo:
     with gr.Row():
         ito_reference_audio = gr.Audio(label="ITO Reference Style Audio $x'_{ref}$ (optional)")
         with gr.Column():
-            num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
+            num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps for additional optimization")
             optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
             learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
             loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
-
-            # af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
-            # clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
-            # clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
 
     # Audio Feature Loss weights
     with gr.Column(visible=True) as audio_feature_weights:
@@ -240,6 +236,7 @@ with gr.Blocks() as demo:
     with gr.Column(visible=False) as clap_options:
         clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio")
        clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
+        clap_distance_fn = gr.Dropdown(["cosine", "mse", "l1"], label="CLAP Distance Function", value="cosine")
 
     def update_clap_options(loss_function):
         if loss_function == "CLAPFeatureLoss":
@@ -285,7 +282,7 @@ with gr.Blocks() as demo:
 
     ito_button.click(
         perform_ito,
-        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt],
+        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights, loss_function, clap_target_type, clap_text_prompt, clap_distance_fn],
         outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
     ).then(
        update_ito_output,
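For context, the new clap_distance_fn dropdown lives inside the clap_options column, whose visibility is toggled by update_clap_options when the loss-function radio changes. The event wiring itself is not part of the visible hunks, so the loss_function.change(...) hookup below is an assumption based on standard Gradio usage; only the column names come from this diff.

import gradio as gr

def update_clap_options(loss_function):
    # Show the CLAP-specific controls (target type, text prompt, distance function)
    # only when CLAPFeatureLoss is selected; otherwise show the AudioFeatureLoss weights.
    use_clap = loss_function == "CLAPFeatureLoss"
    return gr.update(visible=not use_clap), gr.update(visible=use_clap)

# Hypothetical wiring (not shown in this commit's hunks):
# loss_function.change(
#     update_clap_options,
#     inputs=loss_function,
#     outputs=[audio_feature_weights, clap_options],
# )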
inference.py CHANGED
@@ -35,7 +35,7 @@ class MasteringStyleTransfer:
                                             STEMS=['mixture'], \
                                             EFFECTS=['eq', 'imager', 'loudness'])
        # Loss functions
-        self.clap_loss = CLAPFeatureLoss(distance_fn='cosine')
+        self.clap_loss = CLAPFeatureLoss()
 
    def load_effects_encoder(self):
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
@@ -97,7 +97,7 @@ class MasteringStyleTransfer:
                target = reference_tensor
            else:
                target = ito_config['clap_text_prompt']
-            losses = self.clap_loss(output_audio, target, self.args.sample_rate)
+            losses = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
            total_loss = losses
 
            if total_loss < min_loss:
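The net effect of the inference.py change is that the CLAP distance function is now chosen per ITO run instead of being fixed at construction time. A minimal sketch of the resulting call pattern is below; the ito_config values, the compute_clap_loss helper, and the modules.loss import path are illustrative assumptions, while the target selection and loss call mirror the hunk above.

from modules.loss import CLAPFeatureLoss

# Hypothetical ito_config mirroring the keys added in app.py (values illustrative).
ito_config = {
    'clap_target_type': 'Text',
    'clap_text_prompt': 'a warm, punchy master',  # illustrative prompt
    'clap_distance_fn': 'cosine',                 # 'cosine', 'mse', or 'l1'
}

clap_loss = CLAPFeatureLoss()  # no distance_fn at construction any more

def compute_clap_loss(output_audio, reference_tensor, sample_rate):
    # Pick the target exactly as MasteringStyleTransfer does: audio tensor or text prompt.
    if ito_config['clap_target_type'] == 'Audio':
        target = reference_tensor
    else:
        target = ito_config['clap_text_prompt']
    # The distance function is forwarded per call.
    return clap_loss(output_audio, target, sample_rate,
                     distance_fn=ito_config['clap_distance_fn'])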
modules/loss.py CHANGED
@@ -180,23 +180,13 @@ import laion_clap
 import torchaudio
 # CLAP feature loss
 class CLAPFeatureLoss(nn.Module):
-    def __init__(self, distance_fn='mse'):
+    def __init__(self):
         super(CLAPFeatureLoss, self).__init__()
         self.target_sample_rate = 48000  # CLAP expects 48kHz audio
         self.model = laion_clap.CLAP_Module(enable_fusion=False)
         self.model.load_ckpt()  # download the default pretrained checkpoint
-
-        self.distance_fn = distance_fn
-        if distance_fn == 'mse':
-            self.compute_distance = F.mse_loss
-        elif distance_fn == 'l1':
-            self.compute_distance = F.l1_loss
-        elif distance_fn == 'cosine':
-            self.compute_distance = lambda x, y: 1 - F.cosine_similarity(x, y).mean()
-        else:
-            raise ValueError(f"Unsupported distance function: {distance_fn}")
 
-    def forward(self, input_audio, target, sample_rate):
+    def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
         # Process input audio
         input_embed = self.process_audio(input_audio, sample_rate)
 
@@ -209,7 +199,7 @@ class CLAPFeatureLoss(nn.Module):
             raise ValueError("Target must be either audio tensor or text (string or list of strings)")
 
         # Compute loss using the specified distance function
-        loss = self.compute_distance(input_embed, target_embed)
+        loss = self.compute_distance(input_embed, target_embed, distance_fn)
 
         return loss
 
@@ -230,7 +220,8 @@ class CLAPFeatureLoss(nn.Module):
         audio = self.quantize(audio)
 
         # Get CLAP embeddings
-        embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
+        with torch.no_grad():
+            embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
         return embed
 
     def process_text(self, text):
@@ -238,18 +229,29 @@ class CLAPFeatureLoss(nn.Module):
         # ensure input is a list of strings
         if not isinstance(text, list):
             text = [text]
-        embed = self.model.get_text_embedding(text, use_tensor=True)
+        with torch.no_grad():
+            embed = self.model.get_text_embedding(text, use_tensor=True)
         return embed
 
+    def compute_distance(self, x, y, distance_fn):
+        if distance_fn == 'mse':
+            return F.mse_loss(x, y)
+        elif distance_fn == 'l1':
+            return F.l1_loss(x, y)
+        elif distance_fn == 'cosine':
+            return 1 - F.cosine_similarity(x, y).mean()
+        else:
+            raise ValueError(f"Unsupported distance function: {distance_fn}")
+
     def quantize(self, audio):
         audio = audio.squeeze(1)  # Remove channel dimension
         audio = torch.clamp(audio, -1.0, 1.0)
         audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
         return audio
 
-    def resample(self, audio, sample_rate):
+    def resample(self, audio, input_sample_rate):
         resampler = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=self.target_sample_rate
+            orig_freq=input_sample_rate, new_freq=self.target_sample_rate
         ).to(audio.device)
         return resampler(audio)
 
@@ -506,18 +508,3 @@ class AudioFeatureLoss(torch.nn.Module):
 
         return losses
 
-
-
-if __name__ == "__main__":
-    clap_loss = CLAPFeatureLoss(distance_fn='cosine')
-
-    input_audio = torch.randn(1, 2, 44100)
-    target_audio = torch.randn(1, 2, 44100)
-    target_text = "This is a test"
-    sample_rate = 44100
-    loss = clap_loss(input_audio, target_audio, sample_rate)
-    print(loss)
-    loss = clap_loss(input_audio, target_text, sample_rate)
-    print(loss)
-
-    print(loss.item())
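Since the commit drops the old if __name__ == "__main__" smoke test along with the constructor-time distance_fn, an updated version of that check against the new per-call signature might look like the rough sketch below. It reuses the tensor shapes and text prompt from the removed test; the modules.loss import path is an assumption, and running it requires laion_clap to download its default checkpoint.

import torch
from modules.loss import CLAPFeatureLoss

if __name__ == "__main__":
    clap_loss = CLAPFeatureLoss()  # distance_fn is no longer a constructor argument

    input_audio = torch.randn(1, 2, 44100)   # (batch, channels, samples), as in the removed test
    target_audio = torch.randn(1, 2, 44100)
    target_text = "This is a test"
    sample_rate = 44100

    # Audio target, default cosine distance
    print(clap_loss(input_audio, target_audio, sample_rate))
    # Text target, explicitly selecting an alternative distance function
    print(clap_loss(input_audio, target_text, sample_rate, distance_fn='l1'))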