HarryHe and raoyonghui committed
Commit 8d25678
1 Parent(s): 36f9ba6

Update app.py (#2)


- Update app.py (dc256a6700d930bfb04f1474ecb3348ce431614a)


Co-authored-by: Yonghui Rao <[email protected]>

Files changed (1):
  app.py (+25 -10)
app.py CHANGED
```diff
@@ -28,11 +28,21 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 whisper_model = whisper.load_model("turbo")
 
-def detect_speech_language(speech_16k):
-    return whisper_model.detect_language(speech_16k)
+def detect_speech_language(speech_file):
+    # load audio and pad/trim it to fit 30 seconds
+    audio = whisper.load_audio(speech_file)
+    audio = whisper.pad_or_trim(audio)
+
+    # make log-Mel spectrogram and move to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
+
+    # detect the spoken language
+    _, probs = whisper_model.detect_language(mel)
+    return max(probs, key=probs.get)
+
 
 def detect_text_language(text):
-    langid.classify(text)[0]
+    return langid.classify(text)[0]
 
 @torch.no_grad()
 def get_prompt_text(speech_16k, language):
```
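Why the rewrite: the old speech helper passed raw 16 kHz samples straight to `whisper_model.detect_language`, which expects a log-Mel spectrogram already on the model's device, and the old text helper dropped its `return`. The new helper follows the standard openai-whisper recipe, with `n_mels=128` being the bin count the turbo/large-v3 checkpoints expect. A minimal standalone sketch of the same calls (`prompt.wav` is a hypothetical input file):

```python
import langid
import whisper

model = whisper.load_model("turbo")

# speech language: pad/trim to whisper's 30 s window, build the
# 128-bin log-Mel spectrogram, then pick the most probable language code
audio = whisper.pad_or_trim(whisper.load_audio("prompt.wav"))
mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
_, probs = model.detect_language(mel)
print(max(probs, key=probs.get))          # e.g. "en"

# text language: langid.classify returns a (code, score) tuple
print(langid.classify("Hello there")[0])  # "en"
```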
```diff
@@ -41,7 +51,6 @@ def get_prompt_text(speech_16k, language):
     short_prompt_end_ts = 0.0
 
     asr_result = whisper_model.transcribe(speech_16k, language=language)
-    print("asr_result:", asr_result)
     full_prompt_text = asr_result["text"] # whisper asr result
     #text = asr_result["segments"][0]["text"] # whisperx asr result
     shot_prompt_text = ""
@@ -51,8 +60,6 @@ def get_prompt_text(speech_16k, language):
         short_prompt_end_ts = segment['end']
         if short_prompt_end_ts >= 4:
             break
-    print("full prompt text:", full_prompt_text, " shot_prompt_text:", shot_prompt_text,
-          "short_prompt_end_ts:", short_prompt_end_ts)
     return full_prompt_text, shot_prompt_text, short_prompt_end_ts
 
 
```
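For orientation, the two hunks above only strip debug prints from `get_prompt_text`; the surrounding context shows its actual job: transcribe the prompt, then walk whisper's timestamped segments until at least 4 seconds are covered. A minimal sketch of that selection loop over standard `transcribe` output (`prompt.wav` is hypothetical, and the accumulation into the file's oddly spelled `shot_prompt_text` is inferred from context):

```python
import whisper

model = whisper.load_model("turbo")
asr_result = model.transcribe("prompt.wav", language="en")

shot_prompt_text = ""        # keeps the file's own spelling
short_prompt_end_ts = 0.0
for segment in asr_result["segments"]:
    shot_prompt_text += segment["text"]
    short_prompt_end_ts = segment["end"]
    if short_prompt_end_ts >= 4:   # stop once ~4 s of speech is covered
        break
print(shot_prompt_text, short_prompt_end_ts)
```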
```diff
@@ -310,7 +317,7 @@ def maskgct_inference(
     speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
     speech = librosa.load(prompt_speech_path, sr=24000)[0]
 
-    prompt_language = detect_speech_language(speech_16k)
+    prompt_language = detect_speech_language(prompt_speech_path)
     full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(prompt_speech_path,
                                                                               prompt_language)
     # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
```
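Two things to note at this call site: `detect_speech_language` now receives the file path, since the new helper loads audio itself via `whisper.load_audio`, and the following hunk swaps an undefined `language` for `prompt_language`. The untouched `librosa.load` context lines each return a mono float array resampled to the requested rate; a minimal sketch (`prompt.wav` hypothetical, and the role of the 24 kHz copy as the synthesis-rate input is an assumption):

```python
import librosa

# librosa.load returns (samples, sample_rate); [0] keeps the array.
# The same file is decoded twice at the two rates the pipeline uses:
# 16 kHz for the whisper-side processing, 24 kHz presumably for synthesis.
speech_16k = librosa.load("prompt.wav", sr=16000)[0]
speech = librosa.load("prompt.wav", sr=24000)[0]
print(speech_16k.shape, speech.shape)
```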
```diff
@@ -321,7 +328,7 @@ def maskgct_inference(
         device,
         speech_16k,
         short_prompt_text,
-        language,
+        prompt_language,
         target_text,
         target_language,
         target_len,
```
```diff
@@ -393,9 +400,17 @@ iface = gr.Interface(
     outputs=gr.Audio(label="Generated Audio"),
     title="MaskGCT TTS Demo",
     description="""
-    [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
+    ## MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer
+
+    [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750)
+
+    [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct)
+
+    [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct)
+
+    [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
     """
 )
 
 # Launch the interface
-iface.launch(allowed_paths=["./output"])
+iface.launch(allowed_paths=["./output"])
```
 
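Finally, on the launch line: gradio only serves files from its own temp and static directories unless `allowed_paths` whitelists others, which is what lets the demo return wavs written under ./output. A minimal sketch of the pattern, with a stub standing in for `maskgct_inference`:

```python
import gradio as gr

def tts_stub(target_text: str) -> str:
    # stand-in for maskgct_inference: would synthesize speech and
    # return the path of a wav written under ./output
    return "./output/generated.wav"

iface = gr.Interface(
    fn=tts_stub,
    inputs=gr.Textbox(label="Target Text"),
    outputs=gr.Audio(label="Generated Audio"),
    title="MaskGCT TTS Demo",
)
# without allowed_paths, gradio refuses to serve files
# outside its own working directories
iface.launch(allowed_paths=["./output"])
```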