Zulelee's picture
Upload 254 files
5e9cd1d verified
raw
history blame
4.74 kB
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from typing import Any, List, Optional
from sentence_transformers import CrossEncoder
from typing import Optional, Sequence
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from llama_index.bridge.pydantic import Field, PrivateAttr
class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a local cross-encoder.

    Every ``(query, document)`` pair is scored by a ``sentence_transformers``
    ``CrossEncoder`` and only the ``top_n`` best-scoring documents are kept.
    """

    # Model name or filesystem path handed to ``CrossEncoder``.
    model_name_or_path: str = Field()
    # Loaded ``CrossEncoder`` instance; private so it is not a model field.
    _model: Any = PrivateAttr()
    # Maximum number of documents returned by ``compress_documents``.
    top_n: int = Field()
    # Device string passed to ``CrossEncoder`` (default "cuda").
    device: str = Field()
    # Maximum input length passed to ``CrossEncoder``.
    max_length: int = Field()
    # Batch size forwarded to ``CrossEncoder.predict``.
    batch_size: int = Field()
    # Forwarded as ``num_workers`` to ``CrossEncoder.predict``.
    num_workers: int = Field()
    # NOTE: ``show_progress_bar``, ``activation_fct`` and ``apply_softmax`` are
    # not exposed; ``CrossEncoder.predict`` runs with its own defaults for them.

    def __init__(
        self,
        model_name_or_path: str,
        top_n: int = 3,
        device: str = "cuda",
        max_length: int = 1024,
        batch_size: int = 32,
        num_workers: int = 0,
    ):
        """Validate the configuration and load the cross-encoder model.

        Args:
            model_name_or_path: Model name or path passed to ``CrossEncoder``.
            top_n: Maximum number of documents to keep after reranking.
            device: Device on which the cross-encoder is loaded.
            max_length: Maximum input length given to the cross-encoder.
            batch_size: Batch size used when scoring sentence pairs.
            num_workers: Worker count forwarded to ``CrossEncoder.predict``.
        """
        # Initialise the pydantic model first: private attributes can only be
        # assigned reliably once the base class has set up its storage.
        super().__init__(
            top_n=top_n,
            model_name_or_path=model_name_or_path,
            device=device,
            max_length=max_length,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        # Honour the caller's ``max_length`` instead of a hard-coded 1024.
        self._model = CrossEncoder(
            model_name=model_name_or_path,
            max_length=max_length,
            device=device,
        )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: A sequence of documents to rerank.
            query: The query each document is scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            At most ``top_n`` of the input documents, as selected by
            ``topk`` over the cross-encoder scores. The returned objects are
            the input documents themselves; each gets its score stored as a
            plain ``float`` under ``metadata["relevance_score"]``.
        """
        if not documents:  # nothing to score; skip the model call entirely
            return []

        doc_list = list(documents)
        sentence_pairs = [[query, doc.page_content] for doc in doc_list]
        scores = self._model.predict(
            sentences=sentence_pairs,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            convert_to_tensor=True,
        )

        # Never ask ``topk`` for more entries than there are scores.
        top_k = min(self.top_n, len(scores))
        values, indices = scores.topk(top_k)

        final_results = []
        # ``tolist()`` yields plain Python numbers, so the metadata stays
        # serialisable and does not keep tensors (or device memory) alive.
        for score, position in zip(values.tolist(), indices.tolist()):
            doc = doc_list[position]
            doc.metadata["relevance_score"] = score
            final_results.append(doc)
        return final_results
if __name__ == "__main__":
    # Manual smoke test: build a reranker from the project configuration when
    # reranking is enabled. Project imports stay local to the script branch so
    # importing this module does not require them.
    from configs import (
        LLM_MODELS,
        VECTOR_SEARCH_TOP_K,
        SCORE_THRESHOLD,
        TEMPERATURE,
        USE_RERANKER,
        RERANKER_MODEL,
        RERANKER_MAX_LENGTH,
        MODEL_PATH,
    )
    from server.utils import embedding_device

    if USE_RERANKER:
        # Fall back to the default model when the configured one has no entry.
        known_rerankers = MODEL_PATH["reranker"]
        model_path = known_rerankers.get(RERANKER_MODEL, "BAAI/bge-reranker-large")

        print("-----------------model path------------------", model_path, sep="\n")

        reranker = LangchainReranker(
            model_name_or_path=model_path,
            top_n=3,
            device=embedding_device(),
            max_length=RERANKER_MAX_LENGTH,
        )