# Author: thecuong
# Commit: "First commit" (4bb4208)
from typing import List, Literal
from pydantic import BaseModel, Field
from fastapi import FastAPI, APIRouter, Request
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import uvicorn
# The ASGI application instance served by uvicorn.
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted and
# credentialed requests are allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive (Starlette may echo the caller's Origin back for credentialed
# requests) — confirm this is intended before adding cookies or auth.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Sentence-embedding model used by the embeddings endpoint. It is constructed
# at module import time, so importing this module triggers the model load.
# NOTE(review): trust_remote_code=True is forwarded to the Hugging Face loader,
# which (per HF semantics) permits running Python code shipped with the model
# repository — review/pin that code before changing the model id.
EMBEDDING_MODEL_ID = 'Alibaba-NLP/gte-multilingual-base'
model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
# Allowed values for the request's "type" field.
_EmbeddingType = Literal['default', 'disease', 'gte']


# Request body accepted by POST /retrieval/embeddings.
class PostEmbeddings(BaseModel):
    # Requested embedding flavour; optional, falls back to 'default'.
    # NOTE(review): the visible endpoint never reads this field.
    type: _EmbeddingType = Field('default')
    # Texts handed to the encoder, in request order.
    sentences: List[str]
# Every retrieval endpoint is served under the /retrieval prefix.
router = APIRouter(tags=["retrieval"], prefix="/retrieval")


# POST /retrieval/embeddings
# Body: PostEmbeddings. Response: {"data": {"embeddings": <list>}}, where the
# list is model.encode(sentences) converted to plain Python via .tolist().
# NOTE(review): `request` and `data.type` are accepted but never read — every
# request is encoded by the same model regardless of the requested type.
@router.post('/embeddings')
def post_embeddings(request: Request, data: PostEmbeddings):
    vectors = model.encode(data.sentences)
    inner = {"embeddings": vectors.tolist()}
    return {"data": inner}


# Register the router's routes on the app, after the endpoint above is declared.
app.include_router(router)
# Entry point for running the service directly (e.g. `python api.py`).
def main(*, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
    """Start the API server with uvicorn.

    Args:
        host: Interface to bind; the default "0.0.0.0" listens on all interfaces.
        port: TCP port to listen on.
        reload: Enable uvicorn auto-reload. Reload needs an import string, so in
            that mode the app is referenced as "api:app" and this file must be
            importable as module ``api`` from the working directory.

    NOTE(review): when launched as a script this module has already run as
    ``__main__`` (including the model load) before this function executes; with
    reload enabled, the reloader's worker imports ``api`` again and loads its own
    model copy. Consider lazy model loading if that memory matters.
    """
    # Without reload, hand uvicorn the already-built app object: passing the
    # string "api:app" would import this file a second time under the name
    # ``api``, rebuilding the app and loading a second copy of the model.
    target = "api:app" if reload else app
    uvicorn.run(target, host=host, port=port, reload=reload)


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()