from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build
from typing import Dict
import sys

model_image = (
    Image.debian_slim(python_version="3.12")
    .pip_install("chromadb", "sentence-transformers", "pysqlite3-binary")
)

# Utilities
with model_image.imports():
    import os
    import numpy as np

    # Hotswap SQLite version: make `import sqlite3` resolve to the pysqlite3 build.
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

# Application initialization
app = App("mps-api", image=model_image)
vol = Volume.from_name("mps", create_if_missing=False)
data_path = "/data"

# Number of results returned per query when the caller does not specify one.
DEFAULT_N_RESULTS = 10


############
# MAIN CLASS
############
@app.cls(timeout=30*60, volumes={data_path: vol})
class VECTORDB:
    """Vector search over the "MPS" Chroma collection stored on the mounted volume."""

    @enter()
    @build()
    def init(self):
        """Load the embedding model and open the persistent Chroma collection.

        Sets `self.embedding_function` and `self.chroma_collection`.
        """
        # Load encoder
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        model_name = "Lajavaness/sentence-camembert-large"
        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
        print(f"Embedding model loaded: {model_name}")

        # Load vector database
        import chromadb

        DB_PATH = data_path + "/db"
        COLLECTION_NAME = "MPS"
        chroma_client = chromadb.PersistentClient(path=DB_PATH)
        # get_collection (not get_or_create): the collection is expected to exist already.
        self.chroma_collection = chroma_client.get_collection(
            name=COLLECTION_NAME,
            embedding_function=self.embedding_function,
        )
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def search(self, queries, origins, n_results=DEFAULT_N_RESULTS):
        """Query the collection, keeping only entries whose "origin" is in `origins`.

        Args:
            queries: Query text(s), forwarded to Chroma as `query_texts`.
            origins: Allowed values for the "origin" metadata field.
                NOTE(review): an empty list is forwarded as-is to the `$in`
                filter -- confirm how Chroma handles that case.
            n_results: Maximum number of results requested per query.

        Returns:
            A `(documents, metadatas, distances)` tuple taken from the
            corresponding keys of Chroma's query result.
        """
        results = self.chroma_collection.query(
            query_texts=queries,
            n_results=n_results,
            where={"origin": {"$in": origins}},
            include=['documents', 'metadatas', 'distances'],
        )
        documents = results['documents']
        metadatas = results['metadatas']
        distances = results['distances']
        return documents, metadatas, distances


@app.cls(timeout=30*60)
class RANKING:
    """Re-ranks documents against a query with a cross-encoder."""

    @enter()
    @build()
    def init(self):
        """Load the cross-encoder model into `self.cross_encoder`."""
        # Load crossencoder
        from sentence_transformers import CrossEncoder

        model_name = "Lajavaness/CrossEncoder-camembert-large"
        self.cross_encoder = CrossEncoder(model_name)
        print(f"Cross encoder model loaded: {model_name}")

    @method()
    def rank(self, query, documents):
        """Return document indices ordered from highest to lowest score.

        Args:
            query: The query each document is scored against.
            documents: Iterable of documents to score.

        Returns:
            A list of indices into `documents`, best-scoring first. Empty
            when `documents` is empty.
        """
        pairs = [[query, doc] for doc in documents]
        if not pairs:
            # Nothing to score: skip the model call instead of handing it an empty batch.
            return []
        scores = self.cross_encoder.predict(pairs)
        # argsort is ascending; reverse it so the best score comes first.
        ranking = np.argsort(scores)[::-1].tolist()
        return ranking


###########
# ENDPOINTS
###########
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def retrieve(query: Dict):
    """POST endpoint: search the vector database.

    Reads "query" and "origins" from the request body, plus an optional
    "n_results" that falls back to `DEFAULT_N_RESULTS` when absent.

    Returns:
        A dict with "documents", "metadatas" and "distances".
    """
    # Log query
    print(f"Retrieve query: {query}...")

    # Searching documents
    n_results = query.get('n_results', DEFAULT_N_RESULTS)
    documents, metadatas, distances = VECTORDB().search.remote(
        query['query'], query['origins'], n_results
    )
    return {"documents" : documents, "metadatas" : metadatas, "distances" : distances}


@app.function(timeout=30*60)
@web_endpoint(method="POST")
def rank(query: Dict):
    """POST endpoint: rank documents against a query.

    Reads "query" and "documents" from the request body.

    Returns:
        A dict with "ranking", the list of document indices produced by
        `RANKING.rank`.
    """
    # Log query
    print(f"Rank query: {query}...")

    # Ranking documents
    ranking = RANKING().rank.remote(query['query'], query['documents'])
    return {"ranking" : ranking}