import base64
import os
import tempfile

from flask import Flask, request, jsonify
from pydub import AudioSegment
import whisper
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Keep all downloaded model weights inside the container's cache volume.
os.environ['HF_HOME'] = '/app/cache'

# Speech-to-text model (Whisper "base") with an explicit download cache.
whisper_model = whisper.load_model("base", download_root="/app/cache")

# Many-to-many translation model + tokenizer (Facebook M2M100, 418M params).
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")
translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")


def preprocess_audio(audio_path):
    """Convert audio to 16 kHz mono WAV.

    Whisper expects 16 kHz mono input; pydub handles the resample/downmix.

    Returns:
        Path of the newly written processed WAV file.
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # 16 kHz, mono
    processed_path = f"{audio_path}_processed.wav"
    audio.export(processed_path, format="wav")
    return processed_path


def transcribe_audio(audio_path, source_language=None):
    """Transcribe audio using Whisper with an optional source language.

    When ``source_language`` is None (or empty), Whisper auto-detects it.
    """
    options = {"language": source_language} if source_language else {}
    result = whisper_model.transcribe(audio_path, **options)
    return result['text']


def translate_text(text, source_lang="en", target_lang="hi"):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` via M2M100."""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt")
    # forced_bos_token_id tells M2M100 which target language to generate.
    translated_tokens = translation_model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)


def handle_request(audio_base64, source_lang, target_lang):
    """Decode base64 audio, transcribe it, and translate the transcript.

    Args:
        audio_base64: Base64-encoded audio payload.
        source_lang: Spoken language code (also used as translation source).
        target_lang: Translation target language code.

    Returns:
        Dict with ``transcribed_text`` and ``translated_text``.

    Raises:
        ValueError: If ``audio_base64`` is not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    temp_dir = "/app/temp"
    os.makedirs(temp_dir, exist_ok=True)  # Ensure directory exists

    # Use a unique temp file per request: a fixed filename would let
    # concurrent requests overwrite each other's audio (race condition).
    fd, audio_file_path = tempfile.mkstemp(suffix=".wav", dir=temp_dir)
    processed_audio_file_name = None
    try:
        # Decode the base64 audio into the temp file.
        with os.fdopen(fd, "wb") as audio_file:
            audio_file.write(base64.b64decode(audio_base64))

        # Process the audio file, transcribe, then translate.
        processed_audio_file_name = preprocess_audio(audio_file_path)
        spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
        translated_text = translate_text(spoken_text, source_lang, target_lang)
    finally:
        # Clean up temp files even when transcription/translation fails,
        # so failed requests don't leak files into /app/temp.
        for path in (processed_audio_file_name, audio_file_path):
            if path and os.path.exists(path):
                os.remove(path)

    return {"transcribed_text": spoken_text, "translated_text": translated_text}


# Flask for handling external POST requests
app = Flask(__name__)


@app.route('/translate', methods=['POST'])
def translate():
    """API endpoint for handling audio translation.

    Expects a JSON body with ``audio`` (base64), ``source_lang`` and
    ``target_lang`` keys; returns the transcription and translation.
    """
    # silent=True yields None (instead of raising) on a missing or
    # malformed JSON body, so we can answer with a clean 400.
    data = request.get_json(silent=True)
    if not data or 'audio' not in data or 'source_lang' not in data or 'target_lang' not in data:
        return jsonify({"error": "Invalid request format"}), 400

    audio_base64 = data['audio']
    source_lang = data['source_lang']
    target_lang = data['target_lang']

    try:
        # Call the handle_request function to process the request
        response = handle_request(audio_base64, source_lang, target_lang)
    except ValueError:
        # b64decode raises binascii.Error (a ValueError) on bad payloads.
        return jsonify({"error": "Invalid base64 audio data"}), 400

    return jsonify(response)


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)