"""FastAPI service that exposes sentence embeddings over HTTP.

``POST /retrieval/embeddings`` accepts a JSON body with a list of sentences
and returns one embedding vector per sentence. The vectors are computed with
a SentenceTransformer model loaded once when this module is imported.
"""
from typing import List, Literal

import uvicorn
from fastapi import APIRouter, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Initialize FastAPI app
app = FastAPI()

# CORS middleware.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect any caller's Origin on credentialed requests. The endpoint
# below reads no cookies or auth headers, so credentials look unnecessary here.
# Consider an explicit origin list or allow_credentials=False — confirm with
# the API's consumers before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time so every request reuses the same
# instance. trust_remote_code=True permits running custom code shipped with
# the model repository; only point this at a trusted model source.
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)


class PostEmbeddings(BaseModel):
    """Request body for ``POST /retrieval/embeddings``."""

    # Requested embedding flavour. Validated against the literal set below but
    # NOTE(review): not currently used by the handler — every request is
    # served by the single module-level model. Confirm the intended routing.
    # The field name doubles as the JSON key, so it must stay ``type``.
    type: Literal['default', 'disease', 'gte'] = Field(default='default')
    # Sentences to embed; the response preserves this order.
    sentences: List[str]


# Router for embeddings
router = APIRouter(prefix="/retrieval", tags=["retrieval"])


@router.post('/embeddings')
def post_embeddings(request: Request, data: PostEmbeddings):
    """Embed ``data.sentences`` with the shared SentenceTransformer model.

    Args:
        request: The incoming HTTP request. It is unused, but kept so the
            handler signature stays unchanged.
        data: The validated request body.

    Returns:
        ``{"data": {"embeddings": [[float, ...], ...]}}``, with one vector per
        input sentence in input order. An empty ``sentences`` list yields an
        empty ``embeddings`` list.
    """
    if not data.sentences:
        # Nothing to encode: answer directly instead of invoking the model.
        return {"data": {"embeddings": []}}
    embeddings = model.encode(data.sentences)
    # encode() returns a NumPy array; convert it to plain lists for JSON.
    return {"data": {"embeddings": embeddings.tolist()}}


# Include router
app.include_router(router)


def main(
    *,
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = True,
    app_path: str = "api:app",
) -> None:
    """Serve the app with uvicorn.

    All parameters are keyword-only. Their defaults reproduce the original
    hard-coded settings.

    Args:
        host: Interface to bind. The default "0.0.0.0" listens on all
            interfaces.
        port: TCP port to listen on.
        reload: Whether uvicorn restarts the server when source files change
            (a development convenience).
        app_path: Import string ("module:attribute") uvicorn uses to locate
            the app. Uvicorn requires an import string rather than the app
            object when reload is enabled. The default assumes this file is
            importable as ``api``; pass a different value if the module is
            renamed.
    """
    uvicorn.run(app_path, host=host, port=port, reload=reload)


# Run the app if this script is the main module
if __name__ == "__main__":
    main()