File size: 3,110 Bytes
aa7cb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import fastapi
import uvicorn
from fastapi import File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from load_models import get_nllb_model_and_tokenizer, get_xtts_model
from inference_functions import translate, just_inference
import os
import torch

# Set GPU memory fraction: cap this process at 75% of CUDA device 0's memory.
# This runs before the models below are loaded, presumably so their
# allocations fall under the cap. NOTE(review): this call needs a CUDA device;
# it presumably raises on a CPU-only host — confirm.
torch.cuda.set_per_process_memory_fraction(0.75, 0)

# Load models once at import time. Every route handler below reuses these
# module-level objects instead of loading per request.
# NOTE(review): concrete model types come from load_models (not visible here).
model_nllb, tokenizer_nllb = get_nllb_model_and_tokenizer()
model_xtts = get_xtts_model()

# ASGI application that the route decorators below register against.
app = fastapi.FastAPI()

@app.get("/health")
def health_check():
    """Liveness probe: answer with a fixed OK payload while the process is up."""
    ok_payload = dict(status="ok")
    return ok_payload

@app.post("/translate/")
def translate_text(text: str = Form(...), target_lang: str = Form(...)):
    """Translate form field *text* into *target_lang* using the startup NLLB model.

    Returns:
        ``{"translation": <whatever translate() returned>}``.
    """
    return dict(translation=translate(model_nllb, tokenizer_nllb, text, target_lang))

@app.post("/inference/")
def inference_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...)):
    """Synthesize *text* using the uploaded reference clip via ``just_inference``.

    The upload is saved under ``/tmp`` and passed to ``just_inference`` together
    with *text*, *lang* and an output location, which is returned to the caller.

    Args:
        original_path: Uploaded reference audio (despite the name, a file
            upload rather than a path).
        text: Text to synthesize.
        lang: Language code forwarded unchanged to ``just_inference``.

    Returns:
        ``{"path_to_save": "/tmp/generated_audio_<filename>.wav"}``.

    Raises:
        HTTPException: 400 if the upload carries no usable filename.
    """
    import shutil

    # The client controls the filename. Keep only its last path component so a
    # name like "../../etc/cron.d/x" cannot write outside /tmp.
    safe_name = os.path.basename(original_path.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a filename.")
    file_location = os.path.join("/tmp", safe_name)

    # Stream the upload to disk instead of reading it fully into memory.
    with open(file_location, "wb") as file:
        shutil.copyfileobj(original_path.file, file)

    output_path = f"/tmp/generated_audio_{safe_name}.wav"
    # Drop cached, unused GPU blocks before running synthesis.
    torch.cuda.empty_cache()
    just_inference(model_xtts, file_location, output_path, text, lang)
    return {"path_to_save": output_path}

@app.post("/process-audio/")
async def process_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...), target_lang: str = Form(...)):
    """Translate *text* to *target_lang*, then synthesize it from the uploaded clip.

    Pipeline: ``translate`` -> save the upload under ``/tmp`` ->
    ``just_inference`` writing ``/tmp/generated_audio_<filename>.wav``.

    Args:
        original_path: Uploaded reference audio (a file upload, not a path).
        text: Text to translate and synthesize.
        lang: Source-language code. It is only logged here.
            NOTE(review): it is not passed to ``translate`` — confirm intended.
        target_lang: Target language; must be ``"es"`` or ``"en"``.

    Returns:
        JSON ``{"audio_path": <output path>, "translation": <translated text>}``.

    Raises:
        HTTPException: 400 for an unsupported *target_lang* or an upload
            without a usable filename. 500 if translation, saving or synthesis
            fails (the original exception is chained as ``__cause__``).
    """
    import shutil

    from fastapi.concurrency import run_in_threadpool

    print(f"original_path: {original_path.filename}")
    print(f"text: {text}")
    print(f"lang: {lang}")
    print(f"target_lang: {target_lang}")

    # Validate target language
    if target_lang not in ("es", "en"):
        print("Unsupported language")
        # The message used to ask for 'spanish'/'english', which this very check
        # rejects; name the codes that are actually accepted.
        raise HTTPException(status_code=400, detail="Unsupported language. Use 'es' or 'en'.")

    # The client controls the filename. Keep only its last path component so it
    # cannot point outside /tmp (e.g. "../../home/user/.bashrc").
    safe_name = os.path.basename(original_path.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a filename.")
    file_location = os.path.join("/tmp", safe_name)
    output_path = f"/tmp/generated_audio_{safe_name}.wav"

    def _run_pipeline():
        # Blocking model and disk work, in the same order as before.
        translated = translate(model_nllb, tokenizer_nllb, text, target_lang)
        print(f"translated_text: {translated}")
        # Stream the upload to disk rather than reading it all into memory.
        with open(file_location, "wb") as file:
            shutil.copyfileobj(original_path.file, file)
        torch.cuda.empty_cache()
        just_inference(model_xtts, file_location, output_path, translated, target_lang)
        return translated

    try:
        # This is an `async def` route. Calling the blocking pipeline inline
        # would stall the event loop, and with it every other request (including
        # /health), for the whole inference. Run it in a worker thread instead.
        translated_text = await run_in_threadpool(_run_pipeline)
    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail="Error during processing") from e

    return JSONResponse(content={"audio_path": output_path, "translation": translated_text})

@app.get("/download-audio/")
def download_audio(file_path: str):
    """Send a generated WAV file back to the client.

    Only files produced by the synthesis endpoints (regular files named
    ``generated_audio_*`` directly inside ``/tmp``) can be downloaded.

    Args:
        file_path: Path as returned by ``/inference/`` or ``/process-audio/``.

    Returns:
        A ``FileResponse`` streaming the file as ``audio/wav``.

    Raises:
        HTTPException: 404 if the path is outside the allowed location, or is
            not an existing regular file. Both cases return the same status so
            the endpoint does not reveal which files exist.
    """
    # file_path comes straight from the query string. Without this check any
    # file readable by the server (e.g. /etc/passwd) could be downloaded.
    # realpath resolves "..", and symlinks, before the location check.
    resolved = os.path.realpath(file_path)
    allowed_dir = os.path.realpath("/tmp")
    in_allowed_dir = os.path.dirname(resolved) == allowed_dir
    is_generated = os.path.basename(resolved).startswith("generated_audio_")
    # isfile (not exists) also rejects directories, which FileResponse cannot send.
    if not (in_allowed_dir and is_generated and os.path.isfile(resolved)):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved, media_type='audio/wav', filename=os.path.basename(resolved))

if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")