# Hugging Face Space — status: Running
import torch | |
import torchaudio | |
import gradio as gr | |
from transformers import AutoProcessor, AutoModel | |
import warnings | |
import traceback | |
import gc | |
# Silence library warnings so they do not flood the app logs.
# NOTE(review): this ignores *every* warning process-wide; consider narrowing
# it to specific categories/modules.
warnings.filterwarnings("ignore")


class OptimizedContinuousTranslator:
    """Speech-to-text translator backed by SeamlessM4T v2 (large).

    The processor and model are loaded once at construction time. If loading
    fails, ``processor`` and ``model`` are left as ``None`` and
    :meth:`translate_audio` returns an empty string instead of raising.
    """

    # Hugging Face Hub checkpoint used for both the processor and the model.
    MODEL_ID = "facebook/seamless-m4t-v2-large"
    # Audio is always resampled to this rate before feature extraction.
    MODEL_SAMPLE_RATE = 16000

    def __init__(self, target_language="spa", chunk_duration=3, sample_rate=16000):
        """Load the processor/model and place the model on the best available device.

        Args:
            target_language (str): Language code passed as ``tgt_lang`` to
                ``generate`` (default ``"spa"``). Callers may change it later
                by assigning ``self.target_language``.
            chunk_duration (int | float): Stored as ``self.chunk_duration``;
                not used by the current one-shot translation path.
            sample_rate (int): Stored as ``self.sample_rate``; translation
                always resamples to ``MODEL_SAMPLE_RATE``.
        """
        # Plain attributes are set first so they exist even if loading fails.
        self.target_language = target_language
        self.chunk_duration = chunk_duration
        self.sample_rate = sample_rate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
            self.model = AutoModel.from_pretrained(self.MODEL_ID)
            # translate_audio moves its inputs to self.device, so the weights
            # must live there too; otherwise generate() fails on CUDA hosts
            # with a device-mismatch error.
            self.model.to(self.device)
        except Exception as e:
            print(f"Error loading model: {e}")
            self.processor = None
            self.model = None

    def wav_to_tensor(self, file_path, sampling_rate):
        """Load an audio file and resample it to ``sampling_rate`` if needed.

        Args:
            file_path (str): Path to an audio file readable by ``torchaudio.load``.
            sampling_rate (int): Desired output sampling rate in Hz.

        Returns:
            tuple[torch.Tensor, int]: The waveform as returned by
            ``torchaudio.load`` (channels first), resampled when the file's
            native rate differs, and ``sampling_rate``.
        """
        waveform, original_rate = torchaudio.load(file_path)
        # Resample only when the file's native rate differs from the requested one.
        if original_rate != sampling_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=original_rate, new_freq=sampling_rate)
            waveform = resampler(waveform)
        return waveform, sampling_rate

    def translate_audio(self, audio_file_path):
        """Translate the speech in an audio file into ``self.target_language`` text.

        Args:
            audio_file_path (str | None): Path to the recorded audio file, or
                ``None`` when nothing was recorded.

        Returns:
            str: The translated text, or ``""`` when there is no input, the
            model is not loaded, or translation raised. In the last case the
            error and traceback are printed.
        """
        if not audio_file_path or self.processor is None or self.model is None:
            print("translate_audio skipped: no audio file or model not loaded")
            return ""
        try:
            waveform, sample_rate = self.wav_to_tensor(audio_file_path, self.MODEL_SAMPLE_RATE)
            # torchaudio.load yields (channels, frames). Average multi-channel
            # recordings down to one channel so the processor always gets mono.
            if waveform.dim() == 2 and waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            audio_inputs = self.processor(
                audios=waveform.unsqueeze(0), return_tensors="pt", sampling_rate=sample_rate
            )
            # Put every tensor input on the same device as the model.
            audio_inputs = {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in audio_inputs.items()
            }
            # generate_speech=False requests text tokens only (no speech output).
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **audio_inputs,
                    tgt_lang=self.target_language,
                    generate_speech=False,
                )
            # Batch size is 1, so decode the token ids of the first sequence.
            return self.processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Translation error: {str(e)}\n{traceback.format_exc()}")
            return ""
        finally:
            # Release cached GPU memory and collect garbage between requests.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
# Web app
# Simple translator: record a clip, then translate it (not real-time)
def create_translator_interface():
    """Build the Gradio Blocks UI for one-shot (non-streaming) audio translation.

    The UI has a target-language dropdown, a microphone recorder, a
    "Translate" button and a read-only transcript box. A single
    ``OptimizedContinuousTranslator`` (one model load) is created here and
    shared by every request the returned app handles.

    Returns:
        gr.Blocks: The assembled, not yet launched, Gradio app.
    """
    # Load the model once, up front, rather than per request.
    translator = OptimizedContinuousTranslator()
    with gr.Blocks(title="Continuous Audio Translator") as demo:
        # Usage instructions, collapsed inside an HTML <details> element.
        # Blank lines after the HTML tags are required for the Markdown
        # inside <details> to be rendered as headings/lists.
        gr.Markdown("""
## 🎙️ Audio Translator: How to Use

<details>
<summary>Click to view usage instructions</summary>

### Translation Steps

1. **Select Target Language**:
   - Choose the language you want to translate to from the dropdown menu
2. **Record Audio**:
   - Click on the microphone icon in the audio input area
   - Record your audio clearly and concisely
   - Ensure minimal background noise for best results
3. **Translate**:
   - After recording, click the "Translate" button
   - The translated text will appear in the transcript box below

### 💡 Tips

- Speak clearly and at a moderate pace
- Avoid complex or technical language for more accurate translations
- The translation works best with shorter, simpler sentences
- Maximum recommended recording time is around 30 seconds

### Supported Languages

- Input: Currently supports clear spoken language
- Output: Any of the languages you choose from

</details>
""")
        # Target-language codes passed to the model as tgt_lang -> display names.
        languages = {
            "afr": "Afrikaans",
            "amh": "Amharic",
            "arb": "Modern Standard Arabic",
            "ary": "Moroccan Arabic",
            "arz": "Egyptian Arabic",
            "asm": "Assamese",
            "ast": "Asturian",
            "azj": "North Azerbaijani",
            "bel": "Belarusian",
            "ben": "Bengali",
            "bos": "Bosnian",
            "bul": "Bulgarian",
            "cat": "Catalan",
            "ceb": "Cebuano",
            "ces": "Czech",
            "ckb": "Central Kurdish",
            "cmn": "Mandarin Chinese",
            "cmn_Hant": "Mandarin Chinese (Traditional)",
            "cym": "Welsh",
            "dan": "Danish",
            "deu": "German",
            "ell": "Greek",
            "eng": "English",
            "est": "Estonian",
            "eus": "Basque",
            "fin": "Finnish",
            "fra": "French",
            "fuv": "Nigerian Fulfulde",
            "gaz": "West Central Oromo",
            "gle": "Irish",
            "glg": "Galician",
            "guj": "Gujarati",
            "heb": "Hebrew",
            "hin": "Hindi",
            "hrv": "Croatian",
            "hun": "Hungarian",
            "hye": "Armenian",
            "ibo": "Igbo",
            "ind": "Indonesian",
            "isl": "Icelandic",
            "ita": "Italian",
            "jav": "Javanese",
            "jpn": "Japanese",
            "kam": "Kamba",
            "kan": "Kannada",
            "kat": "Georgian",
            "kaz": "Kazakh",
            "kea": "Kabuverdianu",
            "khk": "Halh Mongolian",
            "khm": "Khmer",
            "kir": "Kyrgyz",
            "kor": "Korean",
            "lao": "Lao",
            "lit": "Lithuanian",
            "ltz": "Luxembourgish",
            "lug": "Ganda",
            "luo": "Luo",
            "lvs": "Standard Latvian",
            "mai": "Maithili",
            "mal": "Malayalam",
            "mar": "Marathi",
            "mkd": "Macedonian",
            "mlt": "Maltese",
            "mni": "Meitei",
            "mya": "Burmese",
            "nld": "Dutch",
            "nno": "Norwegian Nynorsk",
            "nob": "Norwegian Bokmål",
            "npi": "Nepali",
            "nya": "Nyanja",
            "oci": "Occitan",
            "ory": "Odia",
            "pan": "Punjabi",
            "pbt": "Southern Pashto",
            "pes": "Western Persian",
            "pol": "Polish",
            "por": "Portuguese",
            "ron": "Romanian",
            "rus": "Russian",
            "slk": "Slovak",
            "slv": "Slovenian",
            "sna": "Shona",
            "snd": "Sindhi",
            "som": "Somali",
            "spa": "Spanish",
            "srp": "Serbian",
            "swe": "Swedish",
            "swh": "Swahili",
            "tam": "Tamil",
            "tel": "Telugu",
            "tgk": "Tajik",
            "tgl": "Tagalog",
            "tha": "Thai",
            "tur": "Turkish",
            "ukr": "Ukrainian",
            "urd": "Urdu",
            "uzn": "Northern Uzbek",
            "vie": "Vietnamese",
            "xho": "Xhosa",
            "yor": "Yoruba",
            "yue": "Cantonese",
            "zlm": "Colloquial Malay",
            "zsm": "Standard Malay",
            "zul": "Zulu",
        }
        # Language dropdown
        with gr.Row():
            # Each choice is (label shown to the user, code passed to the callback).
            language_choices = [(name, code) for code, name in languages.items()]
            language_dropdown = gr.Dropdown(
                choices=language_choices,
                value="spa",  # default is the code, not the display name
                label="Target Language",
                scale=2,
            )
        # Audio input: the callback receives the recording as a file path.
        audio_input = gr.Audio(label="Record Audio", sources="microphone", type="filepath")
        # Output display
        transcript_box = gr.Textbox(label="Full Transcript", lines=10, interactive=False)
        # Control buttons
        with gr.Row():
            start_btn = gr.Button("Translate")

        def handle_translation(audio_file, target_language):
            """Translate one recording into ``target_language`` for the transcript box.

            Args:
                audio_file (str | None): File path of the recording, or a falsy
                    value when nothing was recorded.
                target_language (str | None): Language code selected in the dropdown.

            Returns:
                str: Translated text, or a user-facing status/error message.
            """
            if not audio_file:
                return "No audio file provided. Please record and try again."
            # Guard against an empty dropdown value (e.g. if the user clears it);
            # generate() needs a real tgt_lang code.
            if not target_language:
                return "No target language selected. Please choose one and try again."
            # The translator instance is shared by all requests, so the language
            # is set on it right before each call.
            translator.target_language = target_language
            try:
                translated_text = translator.translate_audio(audio_file)
                return translated_text if translated_text else "Translation failed."
            except Exception as e:
                return f"Error: {str(e)}"

        # Wire the button: (recording, language) -> transcript text.
        start_btn.click(
            fn=handle_translation,
            inputs=[audio_input, language_dropdown],
            outputs=transcript_box,
        )
    return demo
def main():
    """Create the translator UI and serve it without a public share link, surfacing errors in the UI."""
    app = create_translator_interface()
    launch_kwargs = {
        "share": False,
        "show_error": True,
        "debug": True,  # helpful for development
    }
    app.launch(**launch_kwargs)


if __name__ == "__main__":
    # Start the server only when this file is executed directly, not on import.
    main()