jhtonyKoo committed
Commit
f9582e0
1 Parent(s): 6bd330e

modify app

Files changed (1):
inference.py +36 -17
inference.py CHANGED
@@ -114,24 +114,43 @@ class MasteringStyleTransfer:
 
         return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1
 
+    def preprocess_audio(self, audio, target_sample_rate=44100):
+        sample_rate, data = audio
+
+        # Normalize audio to the [-1, 1] float range
+        if data.dtype == np.int16:
+            data = data.astype(np.float32) / 32768.0
+        elif data.dtype == np.float32:
+            data = np.clip(data, -1.0, 1.0)
+        else:
+            raise ValueError(f"Unsupported audio data type: {data.dtype}")
+
+        # Ensure a (channels, samples) stereo layout
+        if data.ndim == 1:
+            data = np.stack([data, data])
+        elif data.ndim == 2:
+            if data.shape[0] == 2:
+                pass  # Already (channels, samples)
+            elif data.shape[1] == 2:
+                data = data.T
+            else:
+                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate the first channel
+        else:
+            raise ValueError(f"Unsupported audio shape: {data.shape}")
+
+        # Convert to a torch tensor with a batch dimension
+        data_tensor = torch.FloatTensor(data).unsqueeze(0)
+
+        # Resample if necessary
+        if sample_rate != target_sample_rate:
+            data_tensor = julius.resample_frac(data_tensor, sample_rate, target_sample_rate)
+
+        return data_tensor.to(self.device)
+
     def process_audio(self, input_audio, reference_audio, ito_reference_audio, params, perform_ito, log_ito=False):
-        print(input_audio[1])
-        input_audio[1], reference_audio[1], ito_reference_audio[1] = [
-            np.stack([audio, audio]) if audio.ndim == 1 else audio.transpose(1, 0)
-            for audio in [input_audio, reference_audio, ito_reference_audio]
-        ]
-
-        input_tensor = torch.FloatTensor(input_audio).unsqueeze(0).to(self.device)
-        reference_tensor = torch.FloatTensor(reference_audio).unsqueeze(0).to(self.device)
-        ito_reference_tensor = torch.FloatTensor(ito_reference_audio).unsqueeze(0).to(self.device)
-
-        # Resample to 44.1 kHz if necessary
-        if input_audio[0] != self.args.sample_rate:
-            input_tensor = convert_audio(input_tensor, input_audio[0], self.args.sample_rate, 2)
-        if reference_audio[0] != self.args.sample_rate:
-            reference_tensor = convert_audio(reference_tensor, reference_audio[0], self.args.sample_rate, 2)
-        if ito_reference_audio[0] != self.args.sample_rate:
-            ito_reference_tensor = convert_audio(ito_reference_tensor, ito_reference_audio[0], self.args.sample_rate, 2)
+        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate)
+        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)
+        ito_reference_tensor = self.preprocess_audio(ito_reference_audio, self.args.sample_rate)
 
         reference_feature = self.get_reference_embedding(reference_tensor)
 
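
For reference, a minimal standalone sketch of what the new preprocessing path does to a Gradio-style (sample_rate, np.ndarray) tuple. The free function below is an adaptation of the committed preprocess_audio method, not the method itself: the device argument and the synthetic input are hypothetical, added so the snippet runs on its own, and it assumes numpy, torch, and julius are installed.

# Standalone sketch adapted from the committed preprocess_audio method;
# the free-function signature (with a "device" argument) is illustrative.
import julius
import numpy as np
import torch

def preprocess_audio(audio, target_sample_rate=44100, device="cpu"):
    sample_rate, data = audio

    # Normalize to the [-1, 1] float range
    if data.dtype == np.int16:
        data = data.astype(np.float32) / 32768.0
    elif data.dtype == np.float32:
        data = np.clip(data, -1.0, 1.0)
    else:
        raise ValueError(f"Unsupported audio data type: {data.dtype}")

    # Force a (channels, samples) stereo layout
    if data.ndim == 1:
        data = np.stack([data, data])
    elif data.ndim == 2 and data.shape[1] == 2:
        data = data.T

    # Add a batch dimension and resample along the last axis
    tensor = torch.FloatTensor(data).unsqueeze(0)
    if sample_rate != target_sample_rate:
        tensor = julius.resample_frac(tensor, sample_rate, target_sample_rate)
    return tensor.to(device)

# One second of mono int16 noise at 22.05 kHz comes out as a
# batched, stereo, resampled float tensor.
mono = (np.random.rand(22050) * 32767).astype(np.int16)
out = preprocess_audio((22050, mono))
print(out.shape)  # torch.Size([1, 2, 44100])

Compared with the old process_audio body, which indexed into the (sample_rate, data) tuples in place and passed the whole tuple to torch.FloatTensor, this folds normalization, channel handling, and resampling into one per-input helper.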