umarigan committed on
Commit
33215f7
1 Parent(s): db969d9

Update app.py

Files changed (1)
  1. app.py +51 -92
app.py CHANGED
@@ -8,13 +8,13 @@ import tempfile
 import soundfile as sf
 import scipy.io.wavfile as wav
 
-from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
+from transformers import VitsModel, AutoTokenizer, set_seed
 from nemo.collections.asr.models import EncDecMultiTaskModel
 
 # Constants
 SAMPLE_RATE = 16000 # Hz
 
-# load ASR model
+# Load ASR model
 canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')
 decode_cfg = canary_model.cfg.decoding
 decode_cfg.beam.beam_size = 1
@@ -34,13 +34,13 @@ def gen_text(audio_filepath, action, source_lang, target_lang):
         converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
         sf.write(converted_audio_filepath, data, SAMPLE_RATE)
 
-        # Transcribe audio
+        # Transcribe or translate audio
         duration = len(data) / SAMPLE_RATE
         manifest_data = {
             "audio_filepath": converted_audio_filepath,
             "taskname": action,
             "source_lang": source_lang,
-            "target_lang": source_lang if action=="asr" else target_lang,
+            "target_lang": source_lang if action == "asr" else target_lang,
             "pnc": "no",
             "answer": "predict",
             "duration": str(duration),
@@ -50,33 +50,13 @@ def gen_text(audio_filepath, action, source_lang, target_lang):
             fout.write(json.dumps(manifest_data))
 
         predicted_text = canary_model.transcribe(manifest_filepath)[0]
-        # if duration < 40:
-        #     predicted_text = canary_model.transcribe(manifest_filepath)[0]
-        # else:
-        #     predicted_text = get_buffered_pred_feat_multitaskAED(
-        #         frame_asr,
-        #         canary_model.cfg.preprocessor,
-        #         model_stride_in_secs,
-        #         canary_model.device,
-        #         manifest=manifest_filepath,
-        #     )[0].text
 
     return predicted_text
 
 # Function to convert text to speech using TTS
 def gen_speech(text, lang):
     set_seed(555) # Make it deterministic
-    match lang:
-        case "en":
-            model = "facebook/mms-tts-eng"
-        case "fr":
-            model = "facebook/mms-tts-fra"
-        case "de":
-            model = "facebook/mms-tts-deu"
-        case "es":
-            model = "facebook/mms-tts-spa"
-        case _:
-            model = "facebook/mms-tts"
+    model = f"facebook/mms-tts-{lang}"
 
     # load TTS model
     tts_model = VitsModel.from_pretrained(model)
@@ -86,75 +66,54 @@ def gen_speech(text, lang):
     with torch.no_grad():
         outputs = tts_model(**input_text)
     waveform_np = outputs.waveform[0].cpu().numpy()
-    output_file = f"{str(uuid.uuid4())}.wav"
-    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
-    return output_file
-
-# Root function for Gradio interface
-def start_process(audio_filepath, source_lang, target_lang):
-    transcription = gen_text(audio_filepath, "asr", source_lang, target_lang)
-    print("Done transcribing")
-    translation = gen_text(audio_filepath, "s2t_translation", source_lang, target_lang)
-    print("Done translation")
-    audio_output_filepath = gen_speech(translation, target_lang)
-    print("Done speaking")
-    return transcription, translation, audio_output_filepath
+    return SAMPLE_RATE, waveform_np
+
+# Main function for speech-to-speech translation
+def speech_to_speech_translation(audio_filepath, source_lang, target_lang):
+    translation = gen_text(audio_filepath, "s2t_translation", LANGUAGES[source_lang][0], LANGUAGES[target_lang][0])
+    sample_rate, synthesized_speech = gen_speech(translation, LANGUAGES[target_lang][1])
+    return sample_rate, synthesized_speech
+
+# Supported languages: display name -> (Canary code, MMS-TTS code)
+LANGUAGES = {
+    "English": ("en", "eng"),
+    "German": ("de", "deu"),
+    "Spanish": ("es", "spa"),
+    "French": ("fr", "fra")
+}
 
-
 # Create Gradio interface
-playground = gr.Blocks()
-
-with playground:
-
-    with gr.Row():
-        gr.Markdown("""
-        ## Your AI Translate Assistant
-        ### Gets input audio from user, transcribe and translate it. Convert back to speech.
-        - category: [Automatic Speech Recognition](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition), model: [nvidia/canary-1b](https://huggingface.co/nvidia/canary-1b)
-        - category: [Text-to-Speech](https://huggingface.co/models?pipeline_tag=text-to-speech), model: [facebook/mms-tts](https://huggingface.co/facebook/mms-tts)
-        """)
-
-    with gr.Row():
-        with gr.Column():
-            source_lang = gr.Dropdown(
-                choices=["en", "de", "es", "fr"], value="en", label="Source Language"
-            )
-        with gr.Column():
-            target_lang = gr.Dropdown(
-                choices=["en", "de", "es", "fr"], value="fr", label="Target Language"
-            )
+demo = gr.Blocks()
 
-    with gr.Row():
-        with gr.Column():
-            input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
-        with gr.Column():
-            translated_speech = gr.Audio(type="filepath", label="Generated Speech")
-
-    with gr.Row():
-        with gr.Column():
-            transcipted_text = gr.Textbox(label="Transcription")
-        with gr.Column():
-            translated_text = gr.Textbox(label="Translation")
-
-    with gr.Row():
-        with gr.Column():
-            submit_button = gr.Button(value="Start Process", variant="primary")
-        with gr.Column():
-            clear_button = gr.ClearButton(components=[input_audio, source_lang, target_lang, transcipted_text, translated_text, translated_speech], value="Clear")
+with demo:
+    gr.Markdown("# Multilingual Speech-to-Speech Translation")
+    gr.Markdown("Translate speech from one language to another.")
 
     with gr.Row():
-        gr.Examples(
-            examples=[
-                ["sample_en.wav","en","fr"],
-                ["sample_fr.wav","fr","de"],
-                ["sample_de.wav","de","es"],
-                ["sample_es.wav","es","en"]
-            ],
-            inputs=[input_audio, source_lang, target_lang],
-            outputs=[transcipted_text, translated_text, translated_speech],
-            run_on_click=True, cache_examples=True, fn=start_process
-        )
-
-    submit_button.click(start_process, inputs=[input_audio, source_lang, target_lang], outputs=[transcipted_text, translated_text, translated_speech])
-
-playground.launch()
+        source_lang = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Source Language")
+        target_lang = gr.Dropdown(choices=list(LANGUAGES.keys()), value="French", label="Target Language")
+
+    with gr.Tabs():
+        with gr.TabItem("Microphone"):
+            mic_input = gr.Audio(sources=["microphone"], type="filepath")
+            mic_output = gr.Audio(label="Generated Speech", type="numpy")
+            mic_button = gr.Button("Translate")
+
+        with gr.TabItem("Audio File"):
+            file_input = gr.Audio(sources=["upload"], type="filepath")
+            file_output = gr.Audio(label="Generated Speech", type="numpy")
+            file_button = gr.Button("Translate")
+
+    mic_button.click(
+        speech_to_speech_translation,
+        inputs=[mic_input, source_lang, target_lang],
+        outputs=mic_output
+    )
+
+    file_button.click(
+        speech_to_speech_translation,
+        inputs=[file_input, source_lang, target_lang],
+        outputs=file_output
+    )
+
+demo.launch()
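For reference, `gen_text()` drives both transcription and translation through a one-line NeMo manifest. A minimal standalone sketch of that call, assuming the `canary_model` loaded above and a 16 kHz mono WAV (the path and duration here are illustrative, not from the commit):

```python
import json

# One manifest entry per utterance; the keys mirror gen_text() above.
manifest_entry = {
    "audio_filepath": "sample_en.wav",  # illustrative 16 kHz mono file
    "taskname": "s2t_translation",      # use "asr" for plain transcription
    "source_lang": "en",
    "target_lang": "fr",                # same as source_lang when taskname is "asr"
    "pnc": "no",                        # no punctuation/capitalization
    "answer": "predict",
    "duration": "5.0",                  # illustrative length in seconds
}

with open("manifest.json", "w") as fout:
    fout.write(json.dumps(manifest_entry))

# transcribe() takes the manifest path and returns one prediction per entry
predicted_text = canary_model.transcribe("manifest.json")[0]
```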
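The TTS half can likewise be exercised on its own. A minimal sketch of VITS-based MMS synthesis, assuming the `facebook/mms-tts-fra` checkpoint (the other ISO 639-3 suffixes in `LANGUAGES`, namely `eng`, `deu`, and `spa`, work the same way):

```python
import torch
from transformers import VitsModel, AutoTokenizer

model_id = "facebook/mms-tts-fra"  # language selected by ISO 639-3 suffix
tts_model = VitsModel.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize the text and run VITS inference
input_text = tokenizer("Bonjour tout le monde.", return_tensors="pt")
with torch.no_grad():
    outputs = tts_model(**input_text)

waveform_np = outputs.waveform[0].cpu().numpy()  # mono float32 waveform
sample_rate = tts_model.config.sampling_rate     # 16 kHz for MMS-TTS
```

Returning a `(sample_rate, waveform)` tuple, as `gen_speech()` now does, matches what `gr.Audio(type="numpy")` expects, which is why the commit drops the intermediate WAV file.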