"""FastAPI service exposing NLLB text translation and XTTS audio inference."""

import logging
import os
import shutil

import fastapi
import torch
import uvicorn
from fastapi import File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse

from inference_functions import just_inference, translate
from load_models import get_nllb_model_and_tokenizer, get_xtts_model

logger = logging.getLogger(__name__)

# Directory where uploaded audio and generated audio files are written.
WORK_DIR = "/tmp"
# Prefix of every file produced by the inference endpoints. /download-audio/
# only serves files carrying this prefix, so it cannot read arbitrary files.
GENERATED_PREFIX = "generated_audio_"
# Target languages accepted by /process-audio/.
SUPPORTED_TARGET_LANGS = ("es", "en")

# Cap this process at 75% of GPU 0's memory. The call raises when CUDA is
# unavailable, so guard it to keep the module importable on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.75, 0)

# Load models once at import time; every request reuses them.
model_nllb, tokenizer_nllb = get_nllb_model_and_tokenizer()
model_xtts = get_xtts_model()

app = fastapi.FastAPI()


def _save_upload(upload: UploadFile) -> str:
    """Write an uploaded file into WORK_DIR and return the path written.

    Only the base name of the client-supplied filename is kept, so names such
    as ``../../etc/x`` cannot escape WORK_DIR.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
    """
    filename = os.path.basename(upload.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a filename.")
    file_location = os.path.join(WORK_DIR, filename)
    with open(file_location, "wb") as file:
        # Stream in chunks instead of reading the whole upload into memory.
        shutil.copyfileobj(upload.file, file)
    return file_location


def _output_path(file_location: str) -> str:
    """Return WORK_DIR/generated_audio_<basename of file_location>.wav."""
    return os.path.join(WORK_DIR, f"{GENERATED_PREFIX}{os.path.basename(file_location)}.wav")


@app.get("/health")
def health_check():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/translate/")
def translate_text(text: str = Form(...), target_lang: str = Form(...)):
    """Translate ``text`` into ``target_lang`` using the NLLB model.

    ``target_lang`` is passed unchanged to ``translate``; its expected format
    is defined there (not validated here).
    """
    translation = translate(model_nllb, tokenizer_nllb, text, target_lang)
    return {"translation": translation}


@app.post("/inference/")
def inference_audio(
    original_path: UploadFile = File(...),
    text: str = Form(...),
    lang: str = Form(...),
):
    """Run XTTS inference for ``text``/``lang`` with the uploaded audio file.

    The uploaded file is presumably the voice reference (see
    ``just_inference``). Returns the output path handed to ``just_inference``;
    clients fetch it via ``/download-audio/``.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
    """
    file_location = _save_upload(original_path)
    output_path = _output_path(file_location)
    # Release cached, unused GPU blocks before a potentially large inference.
    torch.cuda.empty_cache()
    just_inference(model_xtts, file_location, output_path, text, lang)
    return {"path_to_save": output_path}


@app.post("/process-audio/")
def process_audio(
    original_path: UploadFile = File(...),
    text: str = Form(...),
    lang: str = Form(...),
    target_lang: str = Form(...),
):
    """Translate ``text`` into ``target_lang``, then run XTTS on the translation.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool: translation, the file copy and inference are all blocking and
    would otherwise stall the event loop for every other request.

    ``lang`` is only logged; translation and synthesis both use
    ``target_lang``.

    Raises:
        HTTPException: 400 for an unsupported ``target_lang`` or an upload
            without a usable filename; 500 if translation or inference fails.
    """
    logger.info(
        "process-audio: file=%s lang=%s target_lang=%s text=%r",
        original_path.filename,
        lang,
        target_lang,
        text,
    )
    if target_lang not in SUPPORTED_TARGET_LANGS:
        logger.warning("Unsupported language: %s", target_lang)
        raise HTTPException(status_code=400, detail="Unsupported language. Use 'es' or 'en'.")
    try:
        # Translate first so a translation failure does not leave files behind.
        translated_text = translate(model_nllb, tokenizer_nllb, text, target_lang)
        logger.info("translated_text: %s", translated_text)
        file_location = _save_upload(original_path)
        output_path = _output_path(file_location)
        torch.cuda.empty_cache()
        just_inference(model_xtts, file_location, output_path, translated_text, target_lang)
    except HTTPException:
        # Client errors (e.g. bad filename) must keep their 4xx status.
        raise
    except Exception as e:
        logger.exception("Error during processing")
        raise HTTPException(status_code=500, detail="Error during processing") from e
    return JSONResponse(content={"audio_path": output_path, "translation": translated_text})


@app.get("/download-audio/")
def download_audio(file_path: str):
    """Serve a WAV file previously generated by this service.

    ``file_path`` is the path returned by ``/inference/`` or
    ``/process-audio/``. Only regular files located directly in WORK_DIR whose
    name starts with GENERATED_PREFIX are served; anything else (including
    ``..`` tricks and symlinks pointing elsewhere) gets a 404, so this endpoint
    cannot be used to read arbitrary files on the server.

    Raises:
        HTTPException: 404 if the file is missing or not a generated file.
    """
    requested = os.path.realpath(file_path)
    name = os.path.basename(requested)
    if (
        os.path.dirname(requested) != os.path.realpath(WORK_DIR)
        or not name.startswith(GENERATED_PREFIX)
        or not os.path.isfile(requested)
    ):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(requested, media_type="audio/wav", filename=name)


if __name__ == "__main__":
    # Make the module's INFO-level request logs visible when run directly.
    logging.basicConfig(level=logging.INFO)
    uvicorn.run(app, host="0.0.0.0", port=8000)