|
import base64
import binascii
import os
import tempfile
import threading

from flask import Flask, request, jsonify
from pydub import AudioSegment
import whisper
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
|
|
|
|
# Single location for every downloaded model artifact.
CACHE_DIR = '/app/cache'
# Hugging Face model id for the M2M100 translation tokenizer and model.
TRANSLATION_MODEL_NAME = "facebook/m2m100_418M"

# NOTE(review): huggingface_hub presumably reads HF_HOME when it is first
# imported, which has already happened via the transformers import at the top
# of the file, so this assignment may not affect the loads below. The explicit
# cache_dir / download_root arguments are what pin the cache location.
os.environ['HF_HOME'] = CACHE_DIR

# Whisper "base" speech-recognition model, loaded once at import time.
whisper_model = whisper.load_model("base", download_root=CACHE_DIR)

# M2M100 tokenizer and translation model, loaded once at import time and
# shared by every request.
tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL_NAME, cache_dir=CACHE_DIR)
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL_NAME, cache_dir=CACHE_DIR
)
|
|
|
def preprocess_audio(audio_path):
    """Convert an audio file to 16 kHz mono WAV.

    Args:
        audio_path: Path to an audio file that pydub can decode.

    Returns:
        Path of the converted file, ``f"{audio_path}_processed.wav"``.
        The caller is responsible for deleting it.
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1)
    processed_path = f"{audio_path}_processed.wav"
    # AudioSegment.export returns the open output file handle; close it right
    # away so the data is flushed and the file can be removed by the caller.
    audio.export(processed_path, format="wav").close()
    return processed_path
|
|
|
def transcribe_audio(audio_path, source_language=None):
    """Run Whisper speech-to-text on an audio file.

    Args:
        audio_path: Path to the audio file to transcribe.
        source_language: Optional spoken-language hint, forwarded to Whisper
            as ``language``. When falsy, no hint is passed.

    Returns:
        The ``"text"`` field of Whisper's transcription result.
    """
    transcribe_kwargs = {}
    if source_language:
        transcribe_kwargs["language"] = source_language
    transcription = whisper_model.transcribe(audio_path, **transcribe_kwargs)
    return transcription['text']
|
|
|
# Serializes use of the shared module-level tokenizer. Its ``src_lang``
# attribute is mutable state, and Flask's app.run serves requests on multiple
# threads by default.
_TOKENIZER_LOCK = threading.Lock()


def translate_text(text, source_lang="en", target_lang="hi"):
    """Translate text using Facebook's M2M100 model.

    Args:
        text: Text to translate.
        source_lang: M2M100 language code of ``text``.
        target_lang: M2M100 language code to translate into.

    Returns:
        The decoded translation with special tokens removed.
    """
    # Setting src_lang and encoding must happen atomically. Otherwise another
    # thread could change src_lang in between and this text would be encoded
    # with the wrong source language.
    with _TOKENIZER_LOCK:
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt")
        target_token_id = tokenizer.get_lang_id(target_lang)

    # Forcing the first generated token to the target-language id selects the
    # output language.
    translated_tokens = translation_model.generate(
        **inputs,
        forced_bos_token_id=target_token_id,
    )
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
|
|
|
def handle_request(audio_base64, source_lang, target_lang):
    """Transcribe base64-encoded audio and translate the transcript.

    Args:
        audio_base64: Base64-encoded bytes of an audio file.
        source_lang: Language code passed to Whisper as the spoken language
            and to M2M100 as the translation source.
        target_lang: M2M100 language code to translate into.

    Returns:
        dict with ``"transcribed_text"`` and ``"translated_text"``.

    Raises:
        binascii.Error: if ``audio_base64`` is not valid base64. This is raised
            before any temporary file is created.
    """
    audio_bytes = base64.b64decode(audio_base64)

    # Use a unique temp file per call. A fixed name in the working directory
    # would let concurrent requests overwrite each other's audio. The ".wav"
    # suffix matches the previous fixed file name.
    fd, audio_file_path = tempfile.mkstemp(suffix=".wav")
    processed_audio_file_name = None
    try:
        with os.fdopen(fd, "wb") as audio_file:
            audio_file.write(audio_bytes)

        processed_audio_file_name = preprocess_audio(audio_file_path)
        spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
        translated_text = translate_text(spoken_text, source_lang, target_lang)
    finally:
        # Remove the temp files even when preprocessing, transcription or
        # translation raises.
        for path in (processed_audio_file_name, audio_file_path):
            if path is not None:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass

    return {"transcribed_text": spoken_text, "translated_text": translated_text}
|
|
|
|
|
app = Flask(__name__)


@app.route('/translate', methods=['POST'])
def translate():
    """API endpoint for handling audio translation.

    Expects a JSON object with ``audio`` (base64-encoded audio bytes),
    ``source_lang`` and ``target_lang``.

    Returns:
        On success, a JSON body with ``transcribed_text`` and
        ``translated_text``. On bad input, a 400 with ``{"error": ...}``. Bad
        input means the body is not a JSON object, a required key is missing,
        or ``audio`` is not valid base64.
    """
    # silent=True yields None instead of raising on a missing or unparsable
    # JSON body, so every malformed request gets the same 400 below rather
    # than a 500 from indexing a non-dict.
    data = request.get_json(silent=True)
    required_keys = ('audio', 'source_lang', 'target_lang')
    if not isinstance(data, dict) or any(key not in data for key in required_keys):
        return jsonify({"error": "Invalid request format"}), 400

    audio_base64 = data['audio']
    source_lang = data['source_lang']
    target_lang = data['target_lang']

    try:
        response = handle_request(audio_base64, source_lang, target_lang)
    except binascii.Error:
        # Malformed base64 is a client error, not a server failure.
        return jsonify({"error": "Invalid base64 audio"}), 400
    return jsonify(response)
|
|
|
if __name__ == "__main__":
    # Development server: bind to all interfaces ("0.0.0.0") on port 7860.
    bind_host, bind_port = '0.0.0.0', 7860
    app.run(host=bind_host, port=bind_port)
|
|