# API / app.py
# Last commit by Tmeena: "Update app.py" (f14dcf1, verified)
import base64
import binascii
import contextlib
import os
import tempfile

import whisper
from flask import Flask, jsonify, request
from pydub import AudioSegment
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# Single on-disk location shared by every model download in this service.
_CACHE_DIR = '/app/cache'
# Hugging Face checkpoint used for both the tokenizer and the translation model.
_TRANSLATION_CHECKPOINT = "facebook/m2m100_418M"

# Point the Hugging Face cache at the shared directory.
# NOTE(review): `transformers` is already imported above this line; huggingface_hub
# may read HF_HOME at import time, so this assignment might not take effect for it —
# confirm. The explicit cache_dir arguments below do not depend on it.
os.environ['HF_HOME'] = _CACHE_DIR

# Whisper speech-to-text model ("turbo" variant), cached under the shared directory.
whisper_model = whisper.load_model("turbo", download_root=_CACHE_DIR)

# M2M100 tokenizer + seq2seq model used for text-to-text translation.
tokenizer = M2M100Tokenizer.from_pretrained(_TRANSLATION_CHECKPOINT, cache_dir=_CACHE_DIR)
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    _TRANSLATION_CHECKPOINT, cache_dir=_CACHE_DIR
)
def preprocess_audio(audio_path):
    """Convert audio to 16kHz mono WAV format.

    The converted copy is written to a uniquely named ``.wav`` file inside
    ``/app/temp``, so concurrent calls never overwrite each other's output.

    Args:
        audio_path: Path to an audio file readable by ``AudioSegment.from_file``.

    Returns:
        Path of the converted WAV file. The caller owns this file and is
        responsible for deleting it.
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # Set to 16kHz and mono
    processed_dir = "/app/temp"
    # Ensure the directory exists
    os.makedirs(processed_dir, exist_ok=True)
    # Unique path per call: a fixed file name would be shared by concurrent requests.
    fd, processed_path = tempfile.mkstemp(
        prefix="processed_", suffix=".wav", dir=processed_dir
    )
    os.close(fd)  # pydub reopens the path itself; only the unique name is needed
    try:
        # export() returns the file object it opened; close it instead of leaking it.
        audio.export(processed_path, format="wav").close()
    except Exception:
        # Do not leave a partial/empty file behind when conversion fails.
        os.remove(processed_path)
        raise
    return processed_path
def transcribe_audio(audio_path, source_language=None):
    """Run Whisper speech-to-text on ``audio_path`` and return the transcript.

    Args:
        audio_path: Path of the audio file to transcribe.
        source_language: Language hint forwarded to Whisper as ``language``.
            When falsy (None or empty), no hint is passed and Whisper is left
            to pick the language itself.

    Returns:
        The ``'text'`` entry of Whisper's transcription result.
    """
    if not source_language:
        transcription = whisper_model.transcribe(audio_path)
    else:
        transcription = whisper_model.transcribe(audio_path, language=source_language)
    return transcription['text']
def translate_text(text, source_lang="en", target_lang="hi"):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with M2M100.

    Note: this sets ``src_lang`` on the module-level tokenizer shared by all
    callers before encoding.

    Args:
        text: Text to translate.
        source_lang: Language code of ``text`` (defaults to ``"en"``).
        target_lang: Language code to translate into (defaults to ``"hi"``).

    Returns:
        The first generated translation, decoded without special tokens.
    """
    tokenizer.src_lang = source_lang
    encoded_batch = tokenizer(text, return_tensors="pt")
    # M2M100 picks its output language from the first decoder token, so
    # force it to the target language's id.
    target_token_id = tokenizer.get_lang_id(target_lang)
    generated = translation_model.generate(
        **encoded_batch,
        forced_bos_token_id=target_token_id,
    )
    best_sequence = generated[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)
def handle_request(audio_base64, source_lang, target_lang):
    """Handle audio translation request.

    Decodes the base64 payload to a temporary file, converts it with
    ``preprocess_audio``, transcribes it, then translates the transcript.
    Every temporary file is removed afterwards, including when a step fails.

    Args:
        audio_base64: Base64-encoded audio bytes (str or bytes).
        source_lang: Language code of the spoken audio.
        target_lang: Language code to translate the transcript into.

    Returns:
        ``{"transcribed_text": ..., "translated_text": ...}``.

    Raises:
        binascii.Error: If ``audio_base64`` is incorrectly padded base64.
        Any exception raised by conversion, transcription or translation is
        propagated unchanged after cleanup.
    """
    temp_dir = "/app/temp"
    os.makedirs(temp_dir, exist_ok=True)  # Ensure directory exists
    # Unique file per request: a fixed name lets concurrent requests overwrite
    # (or delete) each other's audio.
    fd, audio_file_path = tempfile.mkstemp(
        prefix="temp_audio_", suffix=".wav", dir=temp_dir
    )
    processed_audio_file_name = None
    try:
        # Decode the base64 audio
        with os.fdopen(fd, "wb") as audio_file:
            audio_file.write(base64.b64decode(audio_base64))
        # Process the audio file
        processed_audio_file_name = preprocess_audio(audio_file_path)
        spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
        translated_text = translate_text(spoken_text, source_lang, target_lang)
    finally:
        # Clean up temporary files even on failure, so bad requests do not
        # accumulate files in temp_dir.
        for path in (processed_audio_file_name, audio_file_path):
            if path is not None:
                with contextlib.suppress(FileNotFoundError):
                    os.remove(path)
    return {"transcribed_text": spoken_text, "translated_text": translated_text}
# Flask application object; the route below registers the POST endpoint on it.
app = Flask(__name__)
@app.route('/translate', methods=['POST'])
def translate():
    """API endpoint for handling audio translation.

    Expects a JSON object body with the keys ``audio`` (base64-encoded audio),
    ``source_lang`` and ``target_lang``.

    Returns:
        200 with ``{"transcribed_text": ..., "translated_text": ...}`` on success.
        400 with ``{"error": "Invalid request format"}`` when the body is not a
        JSON object containing all three keys.
        400 with ``{"error": "Invalid base64 audio"}`` when ``audio`` cannot be
        base64-decoded.
    """
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    required_keys = ('audio', 'source_lang', 'target_lang')
    # The isinstance guard matters: for a JSON string `'audio' in data` would be a
    # substring test, and for null/None it would raise TypeError.
    if not isinstance(data, dict) or any(key not in data for key in required_keys):
        return jsonify({"error": "Invalid request format"}), 400
    audio_base64 = data['audio']
    source_lang = data['source_lang']
    target_lang = data['target_lang']
    # Call the handle_request function to process the request
    try:
        response = handle_request(audio_base64, source_lang, target_lang)
    except binascii.Error:
        # Bad client payload (e.g. wrong padding), not a server fault.
        return jsonify({"error": "Invalid base64 audio"}), 400
    return jsonify(response)
if __name__ == "__main__":
    # Development server entry point: listen on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')