"""Qdrant vector-search helpers: user-scoped document and grouped file search."""
import logging
from collections import defaultdict

import numpy as np
import torch
from qdrant_client import QdrantClient
from qdrant_client.http.models import FieldCondition, Filter
class QdrantSearcher:
    """User-scoped similarity search over a Qdrant vector collection.

    Wraps a ``QdrantClient`` and exposes two query styles: per-chunk hits
    (``search_documents``) and per-file aggregated scores
    (``search_documents_grouped``). Every search is filtered to a single
    ``user_id`` and may optionally be narrowed to one ``file_id``.
    """

    def __init__(self, qdrant_url, access_token):
        """Open a Qdrant connection to *qdrant_url*, authenticating with *access_token*."""
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    @staticmethod
    def _to_query_vector(query_embedding):
        """Flatten *query_embedding* (torch.Tensor or numpy.ndarray) to a list of floats.

        Raises:
            ValueError: if the input is neither supported type, or any
                flattened element is not a Python float (e.g. integer dtype).
        """
        if isinstance(query_embedding, torch.Tensor):
            # .cpu() before .numpy(): Tensor.numpy() raises for CUDA tensors.
            flat = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            flat = query_embedding.flatten().tolist()
        else:
            raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
        if not all(isinstance(x, float) for x in flat):
            raise ValueError("All elements in query_embedding must be of type float")
        return flat

    @staticmethod
    def _build_filter(user_id, file_id=None):
        """Build the must-filter restricting hits to *user_id* (and *file_id* when truthy)."""
        conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        if file_id:
            conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        return Filter(must=conditions)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3, similarity_threshold=0.6, file_id=None):
        """Return up to *limit* chunk-level hits for *user_id* scoring >= *similarity_threshold*.

        Returns:
            ``(hits_list, None)`` on success, where each entry carries the hit
            id, score, and selected payload fields; ``(None, error_message)``
            when the search fails or no hit clears the threshold.

        Raises:
            ValueError: if *query_embedding* is not a supported type.
        """
        logging.info("Starting document search")
        query_vector = self._to_query_vector(query_embedding)
        query_filter = self._build_filter(user_id, file_id)
        logging.info("Performing search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                query_filter=query_filter,
            )
        except Exception as e:
            logging.error("Error during Qdrant search: %s", e)
            return None, str(e)
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]
        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        hits_list = [
            {
                "id": hit.id,
                "score": hit.score,
                "file_id": hit.payload.get('file_id'),
                "file_name": hit.payload.get('file_name'),
                "organization_id": hit.payload.get('organization_id'),
                "chunk_index": hit.payload.get('chunk_index'),
                "chunk_text": hit.payload.get('chunk_text'),
                "s3_bucket_key": hit.payload.get('s3_bucket_key'),
            }
            for hit in filtered_hits
        ]
        logging.info("Document search completed with %d hits", len(hits_list))
        logging.info("Hits: %s", hits_list)
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60, similarity_threshold=0.6, file_id=None):
        """Return per-file average scores over the top *limit* hits for *user_id*.

        NOTE(review): *similarity_threshold* is accepted but currently unused —
        the per-hit filter was commented out in the original source, apparently
        deliberately; confirm whether grouped results should also be thresholded.

        Returns:
            ``(grouped_results, None)`` on success — a list of
            ``{"file_name", "average_score"}`` dicts — or ``(None, error_message)``.

        Raises:
            ValueError: if *query_embedding* is not a supported type.
        """
        logging.info("Starting grouped document search")
        query_vector = self._to_query_vector(query_embedding)
        query_filter = self._build_filter(user_id, file_id)
        logging.info("Performing grouped search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                query_filter=query_filter,
            )
        except Exception as e:
            logging.error("Error during Qdrant search: %s", e)
            return None, str(e)
        if not hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        # Collect chunk scores per file, then average so each file gets one score.
        grouped_scores = defaultdict(list)
        for hit in hits:
            grouped_scores[hit.payload.get('file_name')].append(hit.score)
        grouped_results = [
            {"file_name": file_name, "average_score": sum(scores) / len(scores)}
            for file_name, scores in grouped_scores.items()
        ]
        logging.info("Grouped search completed with %d results", len(grouped_results))
        logging.info("Grouped Hits: %s", grouped_results)
        return grouped_results, None