IlyasMoutawwakil's picture
Create handler.py
575fb42 verified
raw
history blame contribute delete
No virus
1.51 kB
from typing import Any, Dict, List
from fastrag.rankers import QuantizedBiEncoderRanker
from fastrag.retrievers import QuantizedBiEncoderRetriever
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.schema import Document
class EndpointHandler:
    """Callable request handler wrapping a retriever -> ranker pipeline.

    Construction indexes a small, hard-coded set of example sentences into
    an in-memory document store and wires a two-node pipeline that is kept
    on ``self.pipe``. Calling the instance runs that pipeline for one query.
    """

    def __init__(self, path=""):
        """Build the document store, retriever, ranker and pipeline.

        Args:
            path: Accepted for interface compatibility; this handler does
                not read it.
        """
        sample_texts = [
            "There is a blue house on Oxford Street.",
            "Paris is the capital of France.",
            "The first commit in fastRAG was in 2022",
        ]

        store = InMemoryDocumentStore(
            use_gpu=False,
            use_bm25=False,
            embedding_dim=384,
            return_embedding=True,
        )
        # Each document's id is its position in the example list.
        store.write_documents(
            [Document(content=text, id=idx) for idx, text in enumerate(sample_texts)]
        )

        retriever_model = "Intel/bge-small-en-v1.5-rag-int8-static"
        retriever = QuantizedBiEncoderRetriever(
            document_store=store,
            embedding_model=retriever_model,
        )
        # Embeddings are computed only after the documents have been written.
        store.update_embeddings(retriever=retriever)

        ranker = QuantizedBiEncoderRanker("Intel/bge-large-en-v1.5-rag-int8-static")

        # Query -> retriever -> ranker
        pipeline = Pipeline()
        pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
        pipeline.add_node(component=ranker, name="ranker", inputs=["retriever"])
        self.pipe = pipeline

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the pipeline for a single request payload.

        Args:
            data: Request payload. The ``"inputs"`` entry is removed from
                it and used as the query; when that key is absent, the
                payload itself is passed as the query.

        Returns:
            Whatever ``self.pipe.run`` returns, unmodified.
            NOTE(review): the declared return annotation may not match the
            pipeline's actual result type -- confirm against the pipeline API.
        """
        # pop() intentionally mutates the caller's payload, as before.
        return self.pipe.run(query=data.pop("inputs", data))