"""FastAPI translation service.

Exposes a JWT-protected ``POST /translate/`` endpoint that translates a batch
of strings with the ``facebook/nllb-200-distilled-600M`` model, and mounts a
small Gradio demo interface under ``CUSTOM_PATH``.
"""

from functools import lru_cache
from typing import Any, Dict, List, Union

import gradio as gr
import jwt
import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from flores200_codes import flores_codes

CUSTOM_PATH = "/gradio"

app = FastAPI()

# This should be a secure secret key in a real application.
# NOTE(review): a hard-coded placeholder secret means anyone can mint valid
# tokens; load the real key from configuration before deploying.
SECRET_KEY = "your_secret_key_here"

# Security scheme: the JWT is read from the ``jwtToken`` query parameter.
# ``auto_error=False`` lets ``verify_token`` produce its own 401 responses.
api_key_query = APIKeyQuery(name="jwtToken", auto_error=False)


class TranslationRequest(BaseModel):
    """Request payload: plain strings or objects carrying the text."""

    strings: List[Union[str, Dict[str, str]]]


class TranslationResponse(BaseModel):
    """Response payload: a mapping of result name to translated strings."""

    data: Dict[str, List[str]]


@lru_cache()
def load_model():
    """Load the translation model and tokenizer.

    The result is cached, so repeated calls return the same objects. The
    model is moved to CUDA when available, otherwise it stays on the CPU.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name_dict = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }
    call_name = "nllb-distilled-600M"
    real_name = model_name_dict[call_name]
    print(f"\tLoading model: {call_name}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(real_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(real_name)
    return model, tokenizer


model, tokenizer = load_model()


def translate_text(text: List[str], source_lang: str, target_lang: str) -> List[str]:
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: The strings to translate.
        source_lang: Key into ``flores_codes`` for the source language.
        target_lang: Key into ``flores_codes`` for the target language.

    Returns:
        The ``translation_text`` of each item produced by the pipeline.

    Raises:
        KeyError: If either language is not a key of ``flores_codes``.
    """
    source = flores_codes[source_lang]
    target = flores_codes[target_lang]
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(text, max_length=400)
    return [item["translation_text"] for item in output]


async def verify_token(token: str = Depends(api_key_query)):
    """Validate the JWT supplied in the ``jwtToken`` query parameter.

    Returns:
        The token string when it decodes successfully with ``SECRET_KEY``.

    Raises:
        HTTPException: 401 when the token is missing or fails validation.
    """
    if not token:
        raise HTTPException(status_code=401, detail={"message": "Token is missing"})
    try:
        jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.PyJWTError as exc:
        # Only token-validation failures map to 401; a bare ``except`` would
        # also hide programming errors and interpreter-exit signals.
        raise HTTPException(
            status_code=401, detail={"message": "Token is invalid"}
        ) from exc
    return token


def _extract_texts(strings: Any) -> List[str]:
    """Normalise the ``strings`` payload into a flat list of texts.

    Accepts a single string, or a list whose entries are either plain strings
    (simple request) or objects with a string ``"text"`` field (extended
    request). Raises ``ValueError`` for any other shape.
    """
    if isinstance(strings, str):
        return [strings]
    if not isinstance(strings, list):
        raise ValueError(
            '"strings" must be a list of strings or objects with a "text" field'
        )
    texts = []
    for index, entry in enumerate(strings):
        text = entry.get("text") if isinstance(entry, dict) else entry
        if not isinstance(text, str):
            raise ValueError(
                f'Entry {index} of "strings" must be a string or an object '
                'with a string "text" field'
            )
        texts.append(text)
    return texts


@app.post("/translate/", response_model=TranslationResponse)
async def translate(
    request: Request,
    source: str,
    target: str,
    project_id: str,
    token: str = Depends(verify_token),
):
    """Translate the strings in the JSON body from ``source`` to ``target``.

    The body must be a JSON object with a ``"strings"`` entry (see
    ``_extract_texts`` for the accepted shapes).

    Returns:
        A ``TranslationResponse`` whose ``data["translations"]`` holds the
        translated strings in input order.

    Raises:
        HTTPException: 400 for missing parameters, a malformed body, an
            empty or malformed ``strings`` value, or an unsupported language;
            500 when the translation itself fails.
    """
    if not all([source, target, project_id]):
        raise HTTPException(
            status_code=400, detail={"message": "Missing required parameters"}
        )

    try:
        data = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(
            status_code=400, detail={"message": "Request body must be valid JSON"}
        ) from exc
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail={"message": "Request body must be a JSON object"}
        )

    strings = data.get("strings", [])
    if not strings:
        raise HTTPException(
            status_code=400, detail={"message": "No strings provided for translation"}
        )

    try:
        texts = _extract_texts(strings)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail={"message": str(exc)}) from exc

    # Reject unknown languages up front so they are reported as a client
    # error instead of surfacing as a KeyError-derived 500.
    for lang in (source, target):
        if lang not in flores_codes:
            raise HTTPException(
                status_code=400,
                detail={"message": f"Unsupported language: {lang}"},
            )

    # NOTE(review): translate_text is a blocking call made directly inside an
    # async endpoint — consider offloading it to a worker thread.
    try:
        translations = translate_text(texts, source, target)
    except Exception as e:
        raise HTTPException(status_code=500, detail={"message": str(e)}) from e
    return TranslationResponse(data={"translations": translations})


@app.get("/logo.png")
async def logo():
    """Return a placeholder for the logo."""
    # TODO: Implement logic to serve the logo
    return "Logo placeholder"


# Demo Gradio interface, mounted on the FastAPI app under CUSTOM_PATH.
io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)