"""FastAPI endpoint that returns SPLADE sparse embeddings for input sentences."""

import logging
import os
import platform
import time
from typing import List

import fastembed
import numpy as np
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastembed import SparseEmbedding, SparseTextEmbedding, TextEmbedding
from pydantic import BaseModel
from transformers import AutoTokenizer
# from sentence_transformers import SentenceTransformer
# from sentence_transformers.util import cos_sim

logger = logging.getLogger(__name__)

# The sparse model is instantiated once at import time and shared by all requests.
sparse_model_name = "prithvida/Splade_PP_en_v1"
sparse_model = SparseTextEmbedding(model_name=sparse_model_name, batch_size=32)


class Validation(BaseModel):
    """Request body for ``/api/generate``: the list of texts to embed."""

    prompt: List[str]


# Imported after the model is constructed, exactly as in the original module
# layout, so the import-time ordering of side effects is unchanged.
from etown_mxbai import app

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _to_jsonable(data):
    """Return ``data.tolist()`` for NumPy arrays, otherwise ``data`` unchanged."""
    return data.tolist() if isinstance(data, np.ndarray) else data


def _serialize_embedding(embedding):
    """Convert one embedding into a ``{"values": ..., "indices": ...}`` dict."""
    if isinstance(embedding, SparseEmbedding):
        return {
            "values": _to_jsonable(embedding.values),
            "indices": _to_jsonable(embedding.indices),
        }
    # Fallback for anything that is not a SparseEmbedding: arrays are listed
    # out, every other object is stringified; positional indices are produced
    # only for arrays and lists.
    is_array = isinstance(embedding, np.ndarray)
    has_positions = isinstance(embedding, (np.ndarray, list))
    return {
        "values": embedding.tolist() if is_array else str(embedding),
        "indices": list(range(len(embedding))) if has_positions else [],
    }


@app.post("/api/generate", summary="Generate embeddings", tags=["Generate"])
def inference(item: Validation):
    """Embed every string in ``item.prompt`` with the sparse model.

    Args:
        item: Validated request body carrying the ``prompt`` list.

    Returns:
        A ``JSONResponse``. On success (HTTP 200) the body holds the
        serialised embeddings, the elapsed time, the number of prompts and
        the model name. On any exception the body holds an error message and
        the exception text, and the response is sent with HTTP status 500 so
        that the transport-level status agrees with the ``status_code`` field
        in the payload.
    """
    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()
        serializable_embeddings = [
            _serialize_embedding(embedding)
            for embedding in sparse_model.embed(item.prompt, batch_size=5)
        ]
        time_taken = time.perf_counter() - started

        return JSONResponse(content={
            "embeddings": serializable_embeddings,
            "time_taken": f"{time_taken:.2f} seconds",
            # Counts the entries of the prompt list (one per input string).
            "Number_of_sentence_processed": len(item.prompt),
            "Model_response_space": sparse_model_name,
            "status_code": 200,
        })
    except Exception as e:
        # Top-level request boundary: record the traceback and answer with a
        # structured error instead of letting the exception propagate.
        logger.exception("An error occurred: %s", e)
        return JSONResponse(
            status_code=500,
            content={
                "error": "An error occurred during processing.",
                "details": str(e),
                "Model_response_space": sparse_model_name,
                "status_code": 500,
            },
        )