raoyonghui committed
Commit a8db66d
1 Parent(s): 0491f05

auto detect prompt language and text

Files changed (1): app.py (+40 -24)
app.py CHANGED
@@ -19,23 +19,44 @@ from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
 
 from transformers import SeamlessM4TFeatureExtractor
 
-# import whisperx
+import whisper
 
 processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
 
 device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
+whisper_model = whisper.load_model("turbo")
 
-# whisper_model = whisperx.load_model("small", "cuda", compute_type="int8")
+def detect_speech_language(speech_file):
+    # load audio and pad/trim it to fit 30 seconds
+    whisper_model = whisper.load_model("turbo")
+    audio = whisper.load_audio(speech_file)
+    audio = whisper.pad_or_trim(audio)
 
-# @torch.no_grad()
-# def get_prompt_text(speech_16k):
-#     asr_result = whisper_model.transcribe(speech_16k)
-#     print("asr_result:", asr_result)
-#     language = asr_result["language"]
-#     #text = asr_result["text"]  # whisper asr result
-#     text = asr_result["segments"][0]["text"]
-#     print("prompt text:", text)
-#     return text, language
+    # make log-Mel spectrogram and move to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
+
+    # detect the spoken language
+    _, probs = whisper_model.detect_language(mel)
+    return max(probs, key=probs.get)
+
+
+@torch.no_grad()
+def get_prompt_text(speech_16k, language):
+    full_prompt_text = ""
+    shot_prompt_text = ""
+    short_prompt_end_ts = 0.0
+
+    asr_result = whisper_model.transcribe(speech_16k, language=language)
+    full_prompt_text = asr_result["text"]  # whisper asr result
+    #text = asr_result["segments"][0]["text"]  # whisperx asr result
+    shot_prompt_text = ""
+    short_prompt_end_ts = 0.0
+    for segment in asr_result["segments"]:
+        shot_prompt_text = shot_prompt_text + segment['text']
+        short_prompt_end_ts = segment['end']
+        if short_prompt_end_ts >= 4:
+            break
+    return full_prompt_text, shot_prompt_text, short_prompt_end_ts
 
 
 def g2p_(text, language):
@@ -279,9 +300,7 @@ def load_models():
 @torch.no_grad()
 def maskgct_inference(
     prompt_speech_path,
-    prompt_text,
     target_text,
-    language="en",
     target_language="en",
     target_len=None,
     n_timesteps=25,
@@ -295,14 +314,17 @@ def maskgct_inference(
     speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
     speech = librosa.load(prompt_speech_path, sr=24000)[0]
 
-    # if prompt_text is None:
-    #     prompt_text, language = get_prompt_text(prompt_speech_path)
-
+    prompt_language = detect_speech_language(prompt_speech_path)
+    full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
+                                                                              prompt_language)
+    # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
+    speech = speech[0: int(shot_prompt_end_ts * 24000)]
+    speech_16k = speech_16k[0: int(shot_prompt_end_ts * 16000)]
     combine_semantic_code, _ = text2semantic(
         device,
         speech_16k,
-        prompt_text,
-        language,
+        short_prompt_text,
+        prompt_language,
         target_text,
         target_language,
         target_len,
@@ -326,20 +348,16 @@ def maskgct_inference(
 @spaces.GPU
 def inference(
     prompt_wav,
-    prompt_text,
     target_text,
     target_len,
     n_timesteps,
-    language,
     target_language,
 ):
     save_path = "./output/output.wav"
     os.makedirs("./output", exist_ok=True)
     recovered_audio = maskgct_inference(
         prompt_wav,
-        prompt_text,
         target_text,
-        language,
         target_language,
         target_len=target_len,
         n_timesteps=int(n_timesteps),
@@ -369,7 +387,6 @@ iface = gr.Interface(
     fn=inference,
     inputs=[
         gr.Audio(label="Upload Prompt Wav", type="filepath"),
-        gr.Textbox(label="Prompt Text"),
         gr.Textbox(label="Target Text"),
         gr.Number(
             label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
@@ -377,7 +394,6 @@ iface = gr.Interface(
         gr.Slider(
             label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
         ),
-        gr.Dropdown(label="Language", choices=language_list, value="en"),
         gr.Dropdown(label="Target Language", choices=language_list, value="en"),
     ],
     outputs=gr.Audio(label="Generated Audio"),
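
The two helpers added in this commit are exercised inside maskgct_inference. For reference, here is a minimal, self-contained sketch of the same prompt-preparation flow, assuming the openai-whisper package is installed; the helper name prepare_prompt and the example path prompt.wav are illustrative and not part of app.py.

import whisper

# Load the ASR model once; app.py does this at import time with the "turbo" checkpoint.
model = whisper.load_model("turbo")

def prepare_prompt(speech_file):
    # Language detection on a 30-second, 128-bin log-Mel view of the audio.
    audio = whisper.pad_or_trim(whisper.load_audio(speech_file))
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
    _, probs = model.detect_language(mel)
    language = max(probs, key=probs.get)

    # Transcribe in the detected language, then stop at the first segment
    # boundary past 4 seconds so an overly long prompt wav is not used in full.
    result = model.transcribe(speech_file, language=language)
    prompt_text, end_ts = "", 0.0
    for segment in result["segments"]:
        prompt_text += segment["text"]
        end_ts = segment["end"]
        if end_ts >= 4:
            break
    return language, prompt_text, end_ts

# Hypothetical usage: trim the 24 kHz prompt audio to the same span,
# as maskgct_inference now does before calling text2semantic.
# language, text, end_ts = prepare_prompt("prompt.wav")
# speech_24k = librosa.load("prompt.wav", sr=24000)[0][: int(end_ts * 24000)]

With the prompt text and prompt language now derived from the uploaded wav, the Gradio interface drops the Prompt Text textbox and the Language dropdown, and inference takes only the prompt wav, target text, target duration, number of timesteps, and target language.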