robinhad committed on
Commit
4d6d915
1 Parent(s): 04f46ca

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +122 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from functools import lru_cache
from typing import List, Union, Dict

import gradio as gr
import jwt
import torch
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from flores200_codes import flores_codes
11
+
12
# URL path under which the Gradio demo UI is mounted on the FastAPI app.
CUSTOM_PATH = "/gradio"

app = FastAPI()

# Secret used to verify HS256-signed JWTs. It is read from the SECRET_KEY
# environment variable so a deployment does not have to rely on a value that
# is committed to the repository. The placeholder literal is kept only as a
# backward-compatible fallback for local development.
SECRET_KEY = os.environ.get("SECRET_KEY", "your_secret_key_here")

# Security scheme: the JWT travels in the "jwtToken" query parameter.
# auto_error=False makes a missing parameter resolve to None instead of an
# automatic error response, so the dependency that consumes it can build its
# own 401 payload.
api_key_query = APIKeyQuery(name="jwtToken", auto_error=False)
22
+
23
+
24
class TranslationRequest(BaseModel):
    """Schema of a translation request body.

    ``strings`` holds the items to translate; every item is either a plain
    string or a mapping from string keys to string values.
    """

    strings: List[Union[str, Dict[str, str]]]
26
+
27
+
28
class TranslationResponse(BaseModel):
    """Schema of a translation response body.

    ``data`` maps a string key to a list of strings.
    """

    data: Dict[str, List[str]]
30
+
31
+
32
@lru_cache()
def load_model():
    """Load the NLLB seq2seq model together with its tokenizer.

    The function is wrapped in ``lru_cache``, so repeated calls return the
    pair produced by the first call instead of loading again.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and to ``"cpu"`` otherwise.
    """
    # Short alias -> Hugging Face checkpoint identifier.
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }
    alias = "nllb-distilled-600M"
    checkpoint = checkpoints[alias]
    print(f"\tLoading model: {alias}")

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    seq2seq = seq2seq.to(target_device)
    vocab = AutoTokenizer.from_pretrained(checkpoint)

    return seq2seq, vocab


# Loaded once at import time; the request handlers reuse these two objects.
model, tokenizer = load_model()
50
+
51
+
52
def translate_text(text: List[str], source_lang: str, target_lang: str) -> List[str]:
    """Translate a batch of strings with the module-level model.

    Args:
        text: The strings to translate.
        source_lang: Key into ``flores_codes`` naming the input language.
        target_lang: Key into ``flores_codes`` naming the output language.

    Returns:
        The ``"translation_text"`` value of every pipeline result, in the
        order the pipeline returned them.

    Any failure of the ``flores_codes`` lookups or of the pipeline call
    propagates to the caller unchanged.
    """
    src_code = flores_codes[source_lang]
    tgt_code = flores_codes[target_lang]

    pipeline_options = {
        "model": model,
        "tokenizer": tokenizer,
        "src_lang": src_code,
        "tgt_lang": tgt_code,
    }
    translator = pipeline("translation", **pipeline_options)
    results = translator(text, max_length=400)

    translated = []
    for entry in results:
        translated.append(entry["translation_text"])
    return translated
66
+
67
+
68
async def verify_token(token: str = Depends(api_key_query)):
    """Validate the JWT supplied through the ``api_key_query`` dependency.

    Args:
        token: The raw token string, or a falsy value when none was sent.

    Returns:
        The token unchanged once it decodes successfully with ``SECRET_KEY``
        and the HS256 algorithm.

    Raises:
        HTTPException: 401 with ``"Token is missing"`` when no token was
            supplied, or 401 with ``"Token is invalid"`` when decoding fails.
    """
    if not token:
        raise HTTPException(status_code=401, detail={"message": "Token is missing"})
    try:
        jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        # Only token problems (bad signature, malformed, expired, ...) map to
        # 401. A bare ``except`` would also swallow KeyboardInterrupt and
        # SystemExit and would disguise unrelated bugs as an invalid token.
        raise HTTPException(
            status_code=401, detail={"message": "Token is invalid"}
        ) from exc
    return token
76
+
77
+
78
@app.post("/translate/", response_model=TranslationResponse)
async def translate(
    request: Request,
    source: str,
    target: str,
    project_id: str,
    token: str = Depends(verify_token),
):
    """Translate the strings in the JSON request body.

    The body must be a JSON object with a ``"strings"`` list. Each entry is
    either a plain string (simple request) or an object carrying the string
    under its ``"text"`` key (extended request); both forms may be mixed.

    Args:
        request: Incoming request; its JSON body is read directly.
        source: Source language, a key of ``flores_codes``.
        target: Target language, a key of ``flores_codes``.
        project_id: Required, but only checked for being non-empty here.
        token: Result of the ``verify_token`` dependency; not used further.

    Returns:
        A ``TranslationResponse`` whose ``data["translations"]`` holds one
        translated string per input entry.

    Raises:
        HTTPException: 400 for missing parameters, unsupported languages, a
            body that is not valid JSON, a missing/empty ``strings`` list, or
            a malformed entry; 500 when the translation itself fails.
    """
    if not all([source, target, project_id]):
        raise HTTPException(
            status_code=400, detail={"message": "Missing required parameters"}
        )

    # Reject unknown languages as a client error instead of letting the
    # lookup inside translate_text surface as a 500.
    unsupported = [lang for lang in (source, target) if lang not in flores_codes]
    if unsupported:
        raise HTTPException(
            status_code=400,
            detail={"message": f"Unsupported language: {', '.join(unsupported)}"},
        )

    try:
        data = await request.json()
    except ValueError as exc:
        # JSON decoding errors (and undecodable bytes) are ValueError subclasses.
        raise HTTPException(
            status_code=400, detail={"message": "Request body is not valid JSON"}
        ) from exc

    # A body that is not a JSON object cannot carry a "strings" list.
    strings = data.get("strings") if isinstance(data, dict) else None
    if not strings or not isinstance(strings, list):
        raise HTTPException(
            status_code=400, detail={"message": "No strings provided for translation"}
        )

    # Normalize every entry on its own rather than trusting the shape of the
    # first one, so mixed simple/extended lists are handled correctly.
    texts = []
    for entry in strings:
        candidate = entry.get("text") if isinstance(entry, dict) else entry
        if not isinstance(candidate, str):
            raise HTTPException(
                status_code=400,
                detail={
                    "message": "Each entry must be a string or an object "
                    "with a string 'text' field"
                },
            )
        texts.append(candidate)

    # NOTE(review): translate_text is synchronous and is called directly from
    # this coroutine; consider offloading it to a worker thread.
    try:
        translations = translate_text(texts, source, target)
    except Exception as e:
        raise HTTPException(status_code=500, detail={"message": str(e)}) from e

    return TranslationResponse(data={"translations": translations})
108
+
109
+
110
@app.get("/logo.png")
async def logo():
    """Handle ``GET /logo.png``; currently returns a fixed placeholder string."""
    # TODO: Implement logic to serve the logo
    placeholder = "Logo placeholder"
    return placeholder
114
+
115
+
116
# Minimal Gradio demo: one textbox in, one textbox out, greeting the input.
io = gr.Interface(
    fn=lambda x: "Hello, " + x + "!",
    inputs="textbox",
    outputs="textbox",
)
# Mount the demo on the FastAPI app under CUSTOM_PATH and rebind ``app`` to
# the object returned by the mount call.
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ flask
2
+ jwt
3
+ transformers
4
+ torch