import base64
import os
from flask import Flask, request, jsonify
from pydub import AudioSegment
import whisper
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Define cache directory
os.environ['HF_HOME'] = '/app/cache'

# Load the Whisper model with a specified cache directory
whisper_model = whisper.load_model("base", download_root="/app/cache")

# Load the translation model and tokenizer
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")
translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")

def preprocess_audio(audio_path):
    """Convert audio to 16kHz mono WAV format."""
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_frame_rate(16000).set_channels(1)  # Set to 16kHz and mono
    processed_path = f"{audio_path}_processed.wav"
    audio.export(processed_path, format="wav")
    return processed_path

def transcribe_audio(audio_path, source_language=None):
    """Transcribe audio using Whisper with an optional source language."""
    options = {"language": source_language} if source_language else {}
    result = whisper_model.transcribe(audio_path, **options)
    return result['text']

def translate_text(text, source_lang="en", target_lang="hi"):
    """Translate text using Facebook's M2M100 model."""
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt")
    translated_tokens = translation_model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
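
# Illustrative usage (assumed example, not called anywhere in this file):
#   translate_text("Hello, how are you?", source_lang="en", target_lang="hi")
# should return the Hindi translation, since M2M100 uses short language codes
# such as "en" and "hi" for its source and target languages.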

def handle_request(audio_base64, source_lang, target_lang):
    """Handle an audio translation request."""
    # Fixed temporary filename; fine for one request at a time,
    # but not safe if the endpoint is hit concurrently.
    audio_file_path = "temp_audio.wav"
    # Decode the base64 audio
    with open(audio_file_path, "wb") as audio_file:
        audio_file.write(base64.b64decode(audio_base64))
    # Process the audio file: resample, transcribe, then translate
    processed_audio_file_name = preprocess_audio(audio_file_path)
    spoken_text = transcribe_audio(processed_audio_file_name, source_lang)
    translated_text = translate_text(spoken_text, source_lang, target_lang)
    # Clean up temporary files
    os.remove(processed_audio_file_name)
    os.remove(audio_file_path)
    return {"transcribed_text": spoken_text, "translated_text": translated_text}

# Flask app for handling external POST requests
app = Flask(__name__)

@app.route('/translate', methods=['POST'])
def translate():
    """API endpoint for handling audio translation."""
    # get_json(silent=True) returns None instead of raising on a non-JSON body
    data = request.get_json(silent=True)
    if not data or 'audio' not in data or 'source_lang' not in data or 'target_lang' not in data:
        return jsonify({"error": "Invalid request format"}), 400
    audio_base64 = data['audio']
    source_lang = data['source_lang']
    target_lang = data['target_lang']
    # Call the handle_request function to process the request
    response = handle_request(audio_base64, source_lang, target_lang)
    return jsonify(response)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
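
# Example client call (illustrative sketch, not part of the service). It assumes
# the server above is running locally on port 7860, that the `requests` package
# is installed, and that `sample.wav` is any short audio file on disk:
#
#   import base64, requests
#
#   with open("sample.wav", "rb") as f:
#       audio_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://localhost:7860/translate",
#       json={"audio": audio_b64, "source_lang": "en", "target_lang": "hi"},
#   )
#   print(resp.json())  # {"transcribed_text": ..., "translated_text": ...}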