Plachta committed
Commit c6d0958 • 1 parent: 2ddbc30

Update app.py

Files changed (1)
  1. app.py +82 -114
app.py CHANGED
@@ -7,7 +7,6 @@ import yaml
 from hf_utils import load_custom_model_from_hf
 import numpy as np
 from pydub import AudioSegment
- import spaces
 
 # Load model and configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -15,6 +14,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
 "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
 "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
+ # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
+ # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
 config = yaml.safe_load(open(dit_config_path, 'r'))
 model_params = recursive_munch(config['model_params'])
 model = build_model(model_params, stage='DiT')
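The DiT checkpoint and its YAML config are fetched from the Plachta/Seed-VC repo on the Hub and parsed before the model is built. A minimal sketch of pulling and parsing the same pair directly with huggingface_hub (independent of the repo's load_custom_model_from_hf helper, which likewise returns local file paths), assuming the filenames sit at the repo root as the call above suggests:

import yaml
from huggingface_hub import hf_hub_download

# Download (or reuse from cache) the checkpoint and config named in the hunk above.
dit_checkpoint_path = hf_hub_download("Plachta/Seed-VC",
                                      "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth")
dit_config_path = hf_hub_download("Plachta/Seed-VC",
                                  "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")

with open(dit_config_path, "r") as f:
    config = yaml.safe_load(f)
print(sorted(config["model_params"].keys()))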
@@ -79,25 +80,14 @@ mel_fn_args = {
 "fmax": None,
 "center": False
 }
- mel_fn_args_f0 = {
- "n_fft": config['preprocess_params']['spect_params']['n_fft'],
- "win_size": config['preprocess_params']['spect_params']['win_length'],
- "hop_size": config['preprocess_params']['spect_params']['hop_length'],
- "num_mels": config['preprocess_params']['spect_params']['n_mels'],
- "sampling_rate": sr,
- "fmin": 0,
- "fmax": None,
- "center": False
- }
 from modules.audio import mel_spectrogram
 
 to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
- to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
 
 # f0 conditioned model
 dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
- "DiT_seed_v2_uvit_facodec_small_wavenet_f0_bigvgan_pruned.pth",
- "config_dit_mel_seed_facodec_small_wavenet_f0.yml")
+ "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ema.pth",
+ "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
 
 config = yaml.safe_load(open(dit_config_path, 'r'))
 model_params = recursive_munch(config['model_params'])
@@ -107,7 +97,7 @@ sr = config['preprocess_params']['sr']
 
 # Load checkpoints
 model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
- load_only_params=True, ignore_modules=[], is_distributed=False)
+ load_only_params=True, ignore_modules=[], is_distributed=False, load_ema=True)
 for key in model_f0:
 model_f0[key].eval()
 model_f0[key].to(device)
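load_ema=True selects the exponential-moving-average copy of the weights that the ..._ema.pth checkpoint carries, instead of the raw training weights. A generic sketch of how such an EMA shadow is typically maintained during training; illustrative only, not the repo's load_checkpoint implementation:

import copy
import torch

def update_ema(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    # Shadow every parameter: ema <- decay * ema + (1 - decay) * current.
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

# Usage sketch: keep a frozen copy next to the trained model, update it after
# every optimizer step, and export it as the "ema" weights for inference.
model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model).eval()
for p in ema_model.parameters():
    p.requires_grad_(False)
update_ema(ema_model, model)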
@@ -119,28 +109,46 @@ from modules.rmvpe import RMVPE
 model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
 rmvpe = RMVPE(model_path, is_half=False, device=device)
 
+ mel_fn_args_f0 = {
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+ "sampling_rate": sr,
+ "fmin": 0,
+ "fmax": None,
+ "center": False
+ }
+ to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
+ bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
+
+ # remove weight norm in the model and set to eval mode
+ bigvgan_44k_model.remove_weight_norm()
+ bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
+
 def adjust_f0_semitones(f0_sequence, n_semitones):
 factor = 2 ** (n_semitones / 12)
 return f0_sequence * factor
 
 def crossfade(chunk1, chunk2, overlap):
- fade_out = np.linspace(1, 0, overlap)
- fade_in = np.linspace(0, 1, overlap)
+ fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+ fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
 chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
 return chunk2
 
 # streaming and chunk processing related params
 max_context_window = sr // hop_length * 30
- overlap_frame_len = 64
+ overlap_frame_len = 16
 overlap_wave_len = overlap_frame_len * hop_length
 bitrate = "320k"
 
- @spaces.GPU
 @torch.no_grad()
 @torch.inference_mode()
 def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 inference_module = model if not f0_condition else model_f0
 mel_fn = to_mel if not f0_condition else to_mel_f0
+ bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
+ sr = 22050 if not f0_condition else 44100
 # Load audio
 source_audio = librosa.load(source, sr=sr)[0]
 ref_audio = librosa.load(target, sr=sr)[0]
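The linear ramps in crossfade are replaced with squared-cosine fades. Since cos^2(θ) + sin^2(θ) = 1, the two gains sum to one at every sample, so a steady signal passes through the overlap without a level dip, and the shorter overlap_frame_len = 16 keeps the blended region at 16 mel frames' worth of samples. A small self-contained check of that property (NumPy only; the copy just avoids mutating the caller's buffer in this test):

import numpy as np

def crossfade(chunk1: np.ndarray, chunk2: np.ndarray, overlap: int) -> np.ndarray:
    # Squared-cosine (constant-gain) fades: fade_in + fade_out == 1 everywhere.
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2 = chunk2.copy()
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

overlap = 256
a = np.ones(1024)
b = np.ones(1024)
out = crossfade(a, b, overlap)
assert np.allclose(out[:overlap], 1.0)  # a constant signal is preserved across the blend
assert np.allclose(np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
                   + np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2, 1.0)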
@@ -151,46 +159,35 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 
 # Resample
 ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-
- # Extract features
- if f0_condition:
- converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
- waves_input = converted_waves_24k.unsqueeze(1)
- max_wave_len_per_chunk = 24000 * 20
- wave_input_chunks = [
- waves_input[..., i:i + max_wave_len_per_chunk] for i in range(0, waves_input.size(-1), max_wave_len_per_chunk)
- ]
- S_alt_chunks = []
- for i, chunk in enumerate(wave_input_chunks):
- z = codec_encoder.encoder(chunk)
- (
- quantized,
- codes
- ) = codec_encoder.quantizer(
- z,
- chunk,
- )
- S_alt = torch.cat([codes[1], codes[0]], dim=1)
- S_alt_chunks.append(S_alt)
- S_alt = torch.cat(S_alt_chunks, dim=-1)
-
- # S_ori should be extracted in the same way
- waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000)
- waves_input = waves_24k.unsqueeze(1)
- z = codec_encoder.encoder(waves_input)
- (
- quantized,
- codes
- ) = codec_encoder.quantizer(
- z,
- waves_input,
+ converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+ # if source audio less than 30 seconds, whisper can handle in one forward
+ if converted_waves_16k.size(-1) <= 16000 * 30:
+ alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=16000)
+ alt_input_features = whisper_model._mask_input_features(
+ alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
+ alt_outputs = whisper_model.encoder(
+ alt_input_features.to(whisper_model.encoder.dtype),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
 )
- S_ori = torch.cat([codes[1], codes[0]], dim=1)
+ S_alt = alt_outputs.last_hidden_state.to(torch.float32)
+ S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
 else:
- converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
- # if source audio less than 30 seconds, whisper can handle in one forward
- if converted_waves_16k.size(-1) <= 16000 * 30:
- alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
+ overlapping_time = 5 # 5 seconds
+ S_alt_list = []
+ buffer = None
+ traversed_time = 0
+ while traversed_time < converted_waves_16k.size(-1):
+ if buffer is None: # first chunk
+ chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
+ else:
+ chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
+ alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
 return_tensors="pt",
 return_attention_mask=True,
 sampling_rate=16000)
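Whisper's feature extractor zero-pads every input to a 30 s window and its encoder emits one hidden state per 320 input samples at 16 kHz (50 frames per second), which is why the hidden states are trimmed back with size(-1) // 320 + 1. The arithmetic, on a hypothetical clip length:

# Hypothetical 12.3 s source clip at 16 kHz.
num_samples = int(12.3 * 16000)

padded_frames = 30 * 50                 # encoder output length for a padded 30 s window
valid_frames = num_samples // 320 + 1   # frames that actually correspond to the clip
print(padded_frames, valid_frames)      # 1500 616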
@@ -204,56 +201,31 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 return_dict=True,
 )
 S_alt = alt_outputs.last_hidden_state.to(torch.float32)
- S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
- else:
- overlapping_time = 5 # 5 seconds
- S_alt_list = []
- buffer = None
- traversed_time = 0
- while traversed_time < converted_waves_16k.size(-1):
- if buffer is None: # first chunk
- chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
- else:
- chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
- alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
- return_tensors="pt",
- return_attention_mask=True,
- sampling_rate=16000)
- alt_input_features = whisper_model._mask_input_features(
- alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
- alt_outputs = whisper_model.encoder(
- alt_input_features.to(whisper_model.encoder.dtype),
- head_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- )
- S_alt = alt_outputs.last_hidden_state.to(torch.float32)
- S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
- if traversed_time == 0:
- S_alt_list.append(S_alt)
- else:
- S_alt_list.append(S_alt[:, 50 * overlapping_time:])
- buffer = chunk[:, -16000 * overlapping_time:]
- traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
- S_alt = torch.cat(S_alt_list, dim=1)
-
- ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
- ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
- return_tensors="pt",
- return_attention_mask=True)
- ori_input_features = whisper_model._mask_input_features(
- ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
- with torch.no_grad():
- ori_outputs = whisper_model.encoder(
- ori_input_features.to(whisper_model.encoder.dtype),
- head_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- )
- S_ori = ori_outputs.last_hidden_state.to(torch.float32)
- S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
+ S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
+ if traversed_time == 0:
+ S_alt_list.append(S_alt)
+ else:
+ S_alt_list.append(S_alt[:, 50 * overlapping_time:])
+ buffer = chunk[:, -16000 * overlapping_time:]
+ traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+ S_alt = torch.cat(S_alt_list, dim=1)
+
+ ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+ ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True)
+ ori_input_features = whisper_model._mask_input_features(
+ ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
+ with torch.no_grad():
+ ori_outputs = whisper_model.encoder(
+ ori_input_features.to(whisper_model.encoder.dtype),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ )
+ S_ori = ori_outputs.last_hidden_state.to(torch.float32)
+ S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
 
 mel = mel_fn(source_audio.to(device).float())
 mel2 = mel_fn(ref_audio.to(device).float())
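For sources longer than 30 s, content features are now extracted over sliding 30 s windows with a 5 s overlap, and every window after the first drops its leading 50 * overlapping_time frames because the encoder re-emits 50 frames per second for the already-seen overlap. A standalone sketch of the same stitching pattern, with a stand-in encoder so it runs without any model weights:

import numpy as np

FRAME_HOP = 320            # one encoder frame per 320 samples at 16 kHz (50 fps)
WINDOW_S, OVERLAP_S, SR = 30, 5, 16000

def fake_encoder(wave: np.ndarray) -> np.ndarray:
    # Stand-in for the Whisper encoder: one "frame" per FRAME_HOP samples.
    return np.arange(len(wave) // FRAME_HOP + 1)

def encode_long(wave: np.ndarray) -> np.ndarray:
    feats, buffer, traversed = [], None, 0
    while traversed < len(wave):
        if buffer is None:                  # first window: a full 30 s
            chunk = wave[traversed:traversed + SR * WINDOW_S]
        else:                               # later windows: 5 s of old audio + up to 25 s of new
            chunk = np.concatenate([buffer, wave[traversed:traversed + SR * (WINDOW_S - OVERLAP_S)]])
        out = fake_encoder(chunk)
        # Drop the frames that re-encode the 5 s overlap (50 frames per second).
        feats.append(out if traversed == 0 else out[50 * OVERLAP_S:])
        buffer = chunk[-SR * OVERLAP_S:]
        traversed += SR * WINDOW_S if traversed == 0 else len(chunk) - SR * OVERLAP_S
    return np.concatenate(feats)

print(encode_long(np.zeros(SR * 70)).shape)   # roughly 70 s x 50 fps frames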
@@ -269,10 +241,8 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 style2 = campplus_model(feat2.unsqueeze(0))
 
 if f0_condition:
- waves_16k = torchaudio.functional.resample(waves_24k, 24000, 16000)
- converted_waves_16k = torchaudio.functional.resample(converted_waves_24k, 24000, 16000)
- F0_ori = rmvpe.infer_from_audio(waves_16k[0], thred=0.03)
- F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
+ F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
+ F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
 
 F0_ori = torch.from_numpy(F0_ori).to(device)[None]
 F0_alt = torch.from_numpy(F0_alt).to(device)[None]
@@ -285,8 +255,6 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
 median_log_f0_ori = torch.median(voiced_log_f0_ori)
 median_log_f0_alt = torch.median(voiced_log_f0_alt)
- # mean_log_f0_ori = torch.mean(voiced_log_f0_ori)
- # mean_log_f0_alt = torch.mean(voiced_log_f0_alt)
 
 # shift alt log f0 level to ori log f0 level
 shifted_log_f0_alt = log_f0_alt.clone()
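With f0_condition enabled, both pitch tracks now come from RMVPE on the 16 kHz waveforms (voicing threshold raised from 0.03 to 0.5), and the source contour is shifted so that its median log-F0 over voiced frames matches the reference's, before any manual pitch_shift of 2^(n/12) per semitone is applied. A hedged sketch of that alignment on plain NumPy arrays, not the repo's exact tensor code:

import numpy as np

def adjust_f0_semitones(f0, n_semitones):
    # One semitone is a factor of 2**(1/12) in frequency.
    return f0 * 2.0 ** (n_semitones / 12)

def align_f0_to_reference(f0_src: np.ndarray, f0_ref: np.ndarray, pitch_shift: int = 0) -> np.ndarray:
    # Voiced frames are the non-zero F0 entries; statistics are taken in log space.
    median_log_src = np.median(np.log(f0_src[f0_src > 1]))
    median_log_ref = np.median(np.log(f0_ref[f0_ref > 1]))
    # Shift the source contour so its median log-F0 matches the reference speaker's.
    shifted = np.where(f0_src > 1,
                       np.exp(np.log(f0_src + 1e-5) - median_log_src + median_log_ref),
                       0.0)
    return adjust_f0_semitones(shifted, pitch_shift)

# Example: a source around 220 Hz mapped toward a reference around 110 Hz.
src = np.array([0.0, 210.0, 220.0, 230.0, 0.0])
ref = np.array([100.0, 110.0, 120.0])
print(align_f0_to_reference(src, ref, pitch_shift=0).round(1))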
@@ -319,7 +287,7 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
 mel2, style2, None, diffusion_steps,
 inference_cfg_rate=inference_cfg_rate)
 vc_target = vc_target[:, :, mel2.size(-1):]
- vc_wave = bigvgan_model(vc_target)[0]
+ vc_wave = bigvgan_fn(vc_target)[0]
 if processed_frames == 0:
 if is_last_chunk:
 output_wave = vc_wave[0].cpu().numpy()
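The vocoder call now dispatches to the 22.05 kHz BigVGAN or the 44.1 kHz bigvgan_44k_model depending on f0_condition; the crossfade helper and overlap_wave_len defined above exist to stitch the decoded chunks of such a streamed synthesis back together. A minimal sketch of that kind of stitch, with hop_length assumed to be 512 purely for illustration:

import numpy as np

def crossfade(chunk1, chunk2, overlap):
    # Same squared-cosine blend as in the diff above; modifies chunk2 in place.
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

hop_length, overlap_frame_len = 512, 16     # hop_length is an assumed value for illustration
overlap_wave_len = overlap_frame_len * hop_length

def stitch_chunks(wave_chunks, overlap):
    # Append each vocoder chunk, crossfading its head into the tail already emitted.
    out = wave_chunks[0]
    for chunk in wave_chunks[1:]:
        blended = crossfade(out, chunk.copy(), overlap)
        out = np.concatenate([out[:-overlap], blended])
    return out

chunks = [np.random.randn(44100) for _ in range(3)]
print(stitch_chunks(chunks, overlap_wave_len).shape)   # (3 * 44100 - 2 * overlap_wave_len,)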
 