jhtonyKoo commited on
Commit
e298cbd
1 Parent(s): cac2c49

update loss

Browse files
app.py CHANGED
@@ -243,9 +243,9 @@ with gr.Blocks() as demo:
243
 
244
  def update_clap_options(loss_function):
245
  if loss_function == "CLAPFeatureLoss":
246
- return gr.update(visible=True), gr.update(visible=True)
247
  else:
248
- return gr.update(visible=False), gr.update(visible=False)
249
 
250
  loss_function.change(
251
  update_clap_options,
@@ -261,7 +261,7 @@ with gr.Blocks() as demo:
261
  inputs=[clap_target_type],
262
  outputs=[clap_text_prompt]
263
  )
264
-
265
  ito_button = gr.Button("Perform ITO")
266
 
267
  with gr.Row():
 
243
 
244
  def update_clap_options(loss_function):
245
  if loss_function == "CLAPFeatureLoss":
246
+ return gr.update(visible=False), gr.update(visible=True)
247
  else:
248
+ return gr.update(visible=True), gr.update(visible=False)
249
 
250
  loss_function.change(
251
  update_clap_options,
 
261
  inputs=[clap_target_type],
262
  outputs=[clap_text_prompt]
263
  )
264
+
265
  ito_button = gr.Button("Perform ITO")
266
 
267
  with gr.Row():
inference.py CHANGED
@@ -91,13 +91,14 @@ class MasteringStyleTransfer:
91
  # Compute loss
92
  if ito_config['loss_function'] == 'AudioFeatureLoss':
93
  losses = af_loss(output_audio, reference_tensor)
 
94
  elif ito_config['loss_function'] == 'CLAPFeatureLoss':
95
  if ito_config['clap_target_type'] == 'Audio':
96
  target = reference_tensor
97
  else:
98
  target = ito_config['clap_text_prompt']
99
  losses = self.clap_loss(output_audio, target, self.args.sample_rate)
100
- total_loss = sum(losses.values())
101
 
102
  if total_loss < min_loss:
103
  min_loss = total_loss.item()
 
91
  # Compute loss
92
  if ito_config['loss_function'] == 'AudioFeatureLoss':
93
  losses = af_loss(output_audio, reference_tensor)
94
+ total_loss = sum(losses.values())
95
  elif ito_config['loss_function'] == 'CLAPFeatureLoss':
96
  if ito_config['clap_target_type'] == 'Audio':
97
  target = reference_tensor
98
  else:
99
  target = ito_config['clap_text_prompt']
100
  losses = self.clap_loss(output_audio, target, self.args.sample_rate)
101
+ total_loss = losses
102
 
103
  if total_loss < min_loss:
104
  min_loss = total_loss.item()
modules/__pycache__/loss.cpython-311.pyc CHANGED
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
 
modules/loss.py CHANGED
@@ -520,3 +520,4 @@ if __name__ == "__main__":
520
  loss = clap_loss(input_audio, target_text, sample_rate)
521
  print(loss)
522
 
 
 
520
  loss = clap_loss(input_audio, target_text, sample_rate)
521
  print(loss)
522
 
523
+ print(loss.item())