# pro-search-api/services/qdrant_searcher.py
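"""Qdrant-backed document search helpers for the pro-search-api service."""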
import logging
from collections import defaultdict

import numpy as np
import torch
from qdrant_client import QdrantClient
from qdrant_client.http.models import FieldCondition, Filter, MatchValue


class QdrantSearcher:
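    """Thin wrapper around QdrantClient for user-scoped document searches."""
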
    def __init__(self, qdrant_url, access_token):
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3,
                         similarity_threshold=0.6, file_id=None):
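        """Search a collection for chunks belonging to user_id.

        Returns a (hits, error) tuple: a list of hit dicts and None on
        success, or None and an error message on failure or empty results.
        """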
logging.info("Starting document search")
# Ensure the query_embedding is in the correct format (flat list of floats)
if isinstance(query_embedding, torch.Tensor):
query_embedding = query_embedding.detach().numpy().flatten().tolist()
elif isinstance(query_embedding, np.ndarray):
query_embedding = query_embedding.flatten().tolist()
else:
raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
# Validate that all elements in the query_vector are floats
if not all(isinstance(x, float) for x in query_embedding):
raise ValueError("All elements in query_embedding must be of type float")
        # Always restrict results to the requesting user; optionally narrow to one file.
        filter_conditions = [FieldCondition(key="user_id", match=MatchValue(value=user_id))]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match=MatchValue(value=file_id)))
        query_filter = Filter(must=filter_conditions)

        logging.info(f"Performing search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)
        # Drop hits whose similarity score falls below the threshold.
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]
        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        hits_list = []
        for hit in filtered_hits:
            hit_info = {
                "id": hit.id,
                "score": hit.score,
                "file_id": hit.payload.get('file_id'),
                "file_name": hit.payload.get('file_name'),
                "organization_id": hit.payload.get('organization_id'),
                "chunk_index": hit.payload.get('chunk_index'),
                "chunk_text": hit.payload.get('chunk_text'),
                "s3_bucket_key": hit.payload.get('s3_bucket_key')
            }
            hits_list.append(hit_info)

        logging.info(f"Document search completed with {len(hits_list)} hits")
        logging.info(f"Hits: {hits_list}")
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60,
                                 similarity_threshold=0.6, file_id=None):
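        """Search a collection and group the hits by file name.

        Each result carries a file name and the average score of its matching
        chunks. Note that similarity_threshold is currently not applied: the
        filtering step below is commented out, so all hits are grouped.
        """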
logging.info("Starting grouped document search")
if isinstance(query_embedding, torch.Tensor):
query_embedding = query_embedding.detach().numpy().flatten().tolist()
elif isinstance(query_embedding, np.ndarray):
query_embedding = query_embedding.flatten().tolist()
else:
raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
if not all(isinstance(x, float) for x in query_embedding):
raise ValueError("All elements in query_embedding must be of type float")
        # Always restrict results to the requesting user; optionally narrow to one file.
        filter_conditions = [FieldCondition(key="user_id", match=MatchValue(value=user_id))]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match=MatchValue(value=file_id)))
        query_filter = Filter(must=filter_conditions)

        logging.info(f"Performing grouped search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)
        # Note: similarity_threshold is intentionally not applied here; all
        # returned hits contribute to the per-file averages.
        # filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]
        if not hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        # Group hits by file name and average the per-chunk scores so each
        # file gets a single relevance score.
        grouped_hits = defaultdict(list)
        for hit in hits:
            grouped_hits[hit.payload.get('file_name')].append(hit.score)

        grouped_results = []
        for file_name, scores in grouped_hits.items():
            average_score = sum(scores) / len(scores)
            grouped_results.append({
                "file_name": file_name,
                "average_score": average_score
            })

        logging.info(f"Grouped search completed with {len(grouped_results)} results")
        logging.info(f"Grouped Hits: {grouped_results}")
        return grouped_results, None
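

# Usage sketch (illustrative only; the URL, token, collection name, and
# embedding size below are assumptions, not values from this repository):
#
#     searcher = QdrantSearcher(qdrant_url="http://localhost:6333", access_token="my-token")
#     query_embedding = np.random.rand(384)  # stand-in for a real model embedding
#     hits, error = searcher.search_documents("documents", query_embedding, user_id="user-123")
#     if error:
#         print(f"Search failed: {error}")
#     else:
#         for hit in hits:
#             print(hit["file_name"], hit["score"])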