Update inference/style_transfer.py
Browse files- inference/style_transfer.py +12 -9
inference/style_transfer.py
CHANGED
@@ -32,7 +32,8 @@ class Mixing_Style_Transfer_Inference:
 32 |             self.device = torch.device("cuda:0")
 33 |         else:
 34 |             self.device = torch.device("cpu")
 35 | -
 36 |         # inference computational hyperparameters
 37 |         self.args = args
 38 |         self.segment_length = args.segment_length
@@ -176,13 +177,14 @@
176 |         fin_data_out_mix = sum(inst_outputs)
177 |
178 |         # loudness adjusting for mastering purpose
179 | -
180 | -
181 | -
182 | -
183 | -
184 | -
185 | -
186 |
187 |         # save output
188 |         fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav")
@@ -382,10 +384,11 @@ def set_up(start_point_in_second=0, duration_in_second=30):
382 |     # FX normalization
383 |     inference_args.add_argument('--normalize_input', type=str2bool, default=True)
384 |     inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
385 |     # interpolation
386 |     inference_args.add_argument('--interpolation', type=str2bool, default=False)
387 |     inference_args.add_argument('--interpolate_segments', type=int, default=30)
388 | -
389 |     device_args = parser.add_argument_group('Device args')
390 |     device_args.add_argument('--workers', type=int, default=1)
391 |     device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
|
|
 32 |             self.device = torch.device("cuda:0")
 33 |         else:
 34 |             self.device = torch.device("cpu")
 35 | +       print(f"using device: {self.device} for inference")
 36 | +
 37 |         # inference computational hyperparameters
 38 |         self.args = args
 39 |         self.segment_length = args.segment_length
|
|
177 |         fin_data_out_mix = sum(inst_outputs)
178 |
179 |         # loudness adjusting for mastering purpose
180 | +       if self.args.match_output_loudness:
181 | +           meter = pyloudnorm.Meter(44100)
182 | +           loudness_out = meter.integrated_loudness(fin_data_out_mix.transpose(-1, -2))
183 | +           reference_aud = load_wav_segment(reference_track_path, axis=1)
184 | +           loudness_ref = meter.integrated_loudness(reference_aud)
185 | +           # adjust output loudness to that of the reference
186 | +           fin_data_out_mix = pyloudnorm.normalize.loudness(fin_data_out_mix, loudness_out, loudness_ref)
187 | +           fin_data_out_mix = np.clip(fin_data_out_mix, -1., 1.)
188 |
189 |         # save output
190 |         fin_output_path = os.path.join(cur_out_dir, f"mixture_{output_name_tag}.wav")
|
|
384 |     # FX normalization
385 |     inference_args.add_argument('--normalize_input', type=str2bool, default=True)
386 |     inference_args.add_argument('--normalization_order', type=str2bool, default=['loudness', 'eq', 'compression', 'imager', 'loudness']) # Effects to be normalized, order matters
387 | +   inference_args.add_argument('--match_output_loudness', type=str2bool, default=False)
388 |     # interpolation
389 |     inference_args.add_argument('--interpolation', type=str2bool, default=False)
390 |     inference_args.add_argument('--interpolate_segments', type=int, default=30)
391 | +
392 |     device_args = parser.add_argument_group('Device args')
393 |     device_args.add_argument('--workers', type=int, default=1)
394 |     device_args.add_argument('--batch_size', type=int, default=1) # for processing long audio
|