from qdrant_client import QdrantClient from qdrant_client.http.models import ScoredPoint from embedding import Embedding from model.document import Document from model.record import Record from model.user import User from qdrant_client.http import models import uuid import tqdm class Index: type: str def load_or_update_document(self, user: User, document: Document, progress: tqdm.tqdm = None): pass def remove_document(self, user: User, document: Document): pass def query_index(self, user: User, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]: pass def query_document(self, user: User, document: Document, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]: pass def contains(self, user: User, document: Document) -> bool: pass class QDrantVectorStore(Index): _client: QdrantClient _embedding: Embedding collection_name: str batch_size: int = 10 type: str = 'qdrant' def __init__( self, client: QdrantClient, embedding: Embedding, collection_name: str): self._embedding = embedding self.collection_name = collection_name self._client = client def _response_to_records(self, response: list[ScoredPoint]) -> list[Record]: for point in response: meta_data = point.payload['meta_data'] yield Record( embedding=point.vector, meta_data= meta_data, content=point.payload['content'], document_id=point.payload['document_id'], timestamp=point.payload['timestamp'], ) def create_collection(self): self._client.recreate_collection( collection_name=self.collection_name, vectors_config=models.VectorParams( size=self._embedding.vector_size, distance=models.Distance.COSINE), ) def if_collection_exists(self) -> bool: try: self._client.get_collection(self.collection_name) return True except Exception: return False def create_collection_if_not_exists(self): if not self.if_collection_exists(): self.create_collection() def load_or_update_document(self, user: User, document: Document, progress: tqdm.tqdm = None): self.create_collection_if_not_exists() if self.contains(user, document): self.remove_document(user, document) group_id = user.user_name # upsert records in batch records = document.load_records() records = list(records) batch_range = range(0, len(records), self.batch_size) if progress is not None: batch_range = progress(batch_range) for i in batch_range: batch = records[i:i+self.batch_size] uuids = [str(uuid.uuid4()) for _ in batch] payloads = [{ 'content': record.content, 'meta_data': record.meta_data, 'document_id': record.document_id, 'group_id': group_id, 'timestamp': record.timestamp, } for record in batch] vectors = [record.embedding for record in batch] self._client.upsert( collection_name=self.collection_name, points=models.Batch( payloads=payloads, ids=uuids, vectors=vectors, ), ) def remove_document(self, user: User, document: Document): if not self.if_collection_exists(): return document_id = document.name self._client.delete( collection_name=self.collection_name, points_selector=models.FilterSelector( filter=models.Filter( must=[ models.FieldCondition( key="document_id", match=models.MatchValue(value=document_id) ), models.FieldCondition( key="group_id", match=models.MatchValue( value=user.user_name, ), ) ] ) ) ) def contains(self, user: User, document: Document) -> bool: document_id = document.name group_id = user.user_name count = self._client.count( collection_name=self.collection_name, count_filter=models.Filter( must=[ models.FieldCondition( key="document_id", match=models.MatchValue(value=document_id) ), models.FieldCondition( key="group_id", match=models.MatchValue( value=group_id, ), ) ] ), exact=True, ) return count.count > 0 def query_index(self, user: User, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]: if not self.if_collection_exists(): return [] response = self._client.search( collection_name=self.collection_name, query_vector=self._embedding.generate_embedding(query), limit=top_k, query_filter= models.Filter( must=[ models.FieldCondition( key="group_id", match=models.MatchValue( value=user.user_name, ), ) ] ), score_threshold=threshold, ) return self._response_to_records(response) def query_document(self, user: User, document: Document, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]: if not self.if_collection_exists(): return [] response = self._client.search( collection_name=self.collection_name, query_vector=self._embedding.generate_embedding(query), limit=top_k, query_filter= models.Filter( must=[ models.FieldCondition( key="document_id", match=models.MatchValue(value=document.name) ), models.FieldCondition( key="group_id", match=models.MatchValue(value=user.user_name), ) ] ), score_threshold=threshold, ) return self._response_to_records(response)