# Spaces: Sleeping  (hosting-page status banner captured with the source; not part of the program)
from typing import List, Literal | |
from pydantic import BaseModel, Field | |
from fastapi import FastAPI, APIRouter, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from sentence_transformers import SentenceTransformer | |
import uvicorn | |
# Application object that the middleware and router below attach to.
app = FastAPI()

# Cross-origin policy: any origin, any method and any header is accepted,
# and credentialed requests are enabled.
# NOTE(review): FastAPI's CORS documentation advises listing origins
# explicitly when allow_credentials=True -- confirm the wildcards are intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Embedding model, loaded once at import time and shared by every request.
# NOTE(review): trust_remote_code=True presumably allows the model repository
# to supply its own Python code -- confirm this repository is trusted.
model = SentenceTransformer(
    'Alibaba-NLP/gte-multilingual-base',
    trust_remote_code=True,
)
# Payload accepted by post_embeddings.
# (Documented with comments rather than a class docstring so the generated
# schema description stays exactly as it was.)
class PostEmbeddings(BaseModel):
    # One of the three allowed labels; falls back to 'default' when omitted.
    # The handler in this file does not read it.
    type: Literal['default', 'disease', 'gte'] = 'default'
    # Texts to be embedded; passed unchanged to model.encode().
    sentences: List[str]
# Router for the embedding endpoints; created with the "/retrieval" prefix
# and the "retrieval" tag.
router = APIRouter(
    tags=["retrieval"],
    prefix="/retrieval",
)
# NOTE(review): the original function carried no route decorator, so it was
# never registered on `router` and the service exposed no embeddings endpoint.
# The "/embeddings" path is inferred from the function name -- confirm it
# against existing clients.
@router.post("/embeddings")
def post_embeddings(request: Request, data: PostEmbeddings):
    """Encode the submitted sentences and return their embeddings.

    Args:
        request: The incoming HTTP request (not used by the handler itself).
        data: Request payload; only ``data.sentences`` is read here, the
            ``type`` field is currently ignored.

    Returns:
        A dict of the form ``{"data": {"embeddings": [...]}}`` where the
        inner value is the result of ``model.encode`` converted with
        ``.tolist()``.
    """
    embeddings = model.encode(data.sentences)
    # .tolist() turns the encoder output into plain Python lists.
    return {"data": {"embeddings": embeddings.tolist()}}
# Attach the retrieval router (and whatever routes are registered on it)
# to the application.
app.include_router(router)
def main():
    """Start uvicorn for this application.

    Serves the import string "api:app" on 0.0.0.0:8000 with reload enabled.
    NOTE(review): "api:app" assumes this module is importable as ``api`` --
    confirm against the actual file name.
    """
    uvicorn.run(
        "api:app",
        reload=True,
        port=8000,
        host="0.0.0.0",
    )
# Script entry point: start the server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()