thecuong commited on
Commit
5af50a4
1 Parent(s): c87a1ad

feat: update

Browse files
Files changed (1) hide show
  1. app.py +20 -59
app.py CHANGED
@@ -1,65 +1,26 @@
1
- from typing import List, Literal
2
- from pydantic import BaseModel, Field
3
  import gradio as gr
4
- from fastapi import FastAPI, APIRouter, Request
5
- from fastapi.middleware.cors import CORSMiddleware
6
  from sentence_transformers import SentenceTransformer
7
- import uvicorn
8
- import requests
9
- import asyncio
10
- import threading
11
-
12
- # Khởi tạo FastAPI
13
- app = FastAPI()
14
-
15
- # Thêm middleware CORS để cho phép yêu cầu từ Gradio
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"], # Cho phép tất cả các nguồn
19
- allow_credentials=True,
20
- allow_methods=["*"],
21
- allow_headers=["*"],
22
- )
23
 
24
  # Tải mô hình
25
- model = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
26
-
27
- # Định nghĩa mô hình dữ liệu cho yêu cầu
28
- class PostEmbeddings(BaseModel):
29
- type: Literal['default', 'disease', 'gte'] = Field(default='default')
30
- sentences: List[str]
31
-
32
- # Tạo router cho API
33
- router = APIRouter(
34
- prefix="/retrieval",
35
- tags=["retrieval"],
36
- responses={404: {"description": "Not found"}},
 
 
 
 
 
 
37
  )
38
 
39
- @app.post("/retrieval/embeddings")
40
- def post_embeddings(data: PostEmbeddings):
41
- embeddings = model.encode(data.sentences)
42
- return {
43
- 'data': {
44
- 'embeddings': embeddings.tolist(),
45
- 'type': data.type
46
- }
47
- }
48
-
49
- # Hàm Gradio để gọi API FastAPI
50
-
51
-
52
- # async def run_gradio():
53
- # demo.launch(share=True)
54
-
55
- async def run_uvicorn():
56
- config = uvicorn.Config("app:app", host="0.0.0.0", port=8000, reload=True)
57
- server = uvicorn.Server(config)
58
- await server.serve()
59
-
60
- # async def main():
61
- # await asyncio.gather(run_uvicorn(), run_gradio())
62
-
63
- # Khởi động server
64
- if __name__ == "__main__":
65
- asyncio.run(run_uvicorn())
 
 
 
1
  import gradio as gr
 
 
2
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Tải mô hình
5
+ model = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-multilingual-base',
6
+ trust_remote_code=True)
7
+
8
+ def gte_model(sentences: list):
9
+ try:
10
+ # Mã hóa các câu
11
+ embeddings = model.encode(sentences)
12
+ return embeddings.tolist() # Chuyển đổi numpy array sang danh sách
13
+ except Exception as e:
14
+ return f"Error: {str(e)}"
15
+
16
+ # Tạo giao diện Gradio
17
+ demo = gr.Interface(
18
+ fn=gte_model,
19
+ inputs=gr.inputs.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
20
+ outputs=gr.outputs.JSON(label="Kết quả mã hóa"),
21
+ title="Mô hình GTE Multilingual",
22
+ description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual. Kết quả sẽ được trả về dưới dạng danh sách mã hóa."
23
  )
24
 
25
+ # Khởi chạy giao diện
26
+ demo.launch()