# Bot_Test/script/vector_db.py
import logging
from math import ceil

import numpy as np
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from llama_index.core import StorageContext, VectorStoreIndex
# from llama_index.core import Settings
from llama_index.vector_stores.pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec

from config import PINECONE_CONFIG

class IndexManager:
    """Manage a Pinecone-backed vector index built with LlamaIndex."""

    def __init__(self, index_name: str = "multimodal-index"):
        self.vector_index = None
        self.index_name = index_name
        self.client = self._get_pinecone_client()
        self.pinecone_index = self._create_pinecone_index()
    def _get_pinecone_client(self):
        """Initialize and return the Pinecone client."""
        # api_key = os.getenv("PINECONE_API_KEY")
        api_key = PINECONE_CONFIG.PINECONE_API_KEY
        if not api_key:
            raise ValueError(
                "Pinecone API key is missing. Please set it in environment variables."
            )
        return Pinecone(api_key=api_key)

    def _create_pinecone_index(self):
        """Create the Pinecone index if it doesn't already exist, then return a handle to it."""
        if self.index_name not in self.client.list_indexes().names():
            self.client.create_index(
                name=self.index_name,
                # Dimension must match the embedding model used to build the nodes.
                dimension=3072,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )
        return self.client.Index(self.index_name)

    def _initialize_vector_store(self) -> StorageContext:
        """Wrap the Pinecone index in a vector store and return its storage context."""
        vector_store = PineconeVectorStore(pinecone_index=self.pinecone_index)
        return StorageContext.from_defaults(vector_store=vector_store)
    def build_indexes(self, nodes):
        """Build the vector index from the given nodes."""
        try:
            storage_context = self._initialize_vector_store()
            self.vector_index = VectorStoreIndex(nodes, storage_context=storage_context)
        except HTTPException:
            raise  # Re-raise HTTPExceptions so FastAPI can handle them directly
        except Exception as e:
            logging.error("Error building index: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error building indexes: {str(e)}",
            )
    def get_ids_from_query(self, input_vector, title):
        """Collect the IDs of all vectors whose `title` metadata matches the given title."""
        logging.info("Searching Pinecone for title: %s", title)
        new_ids = set()
        while True:
            results = self.pinecone_index.query(
                vector=input_vector,
                top_k=10000,
                filter={
                    "title": {"$eq": title},
                },
            )
            ids = {match["id"] for match in results["matches"]}
            # Stop once a round of results adds no IDs we haven't already collected.
            if ids.issubset(new_ids):
                break
            new_ids.update(ids)
        return new_ids
    def get_all_ids_from_index(self, title):
        """Return the IDs of every vector stored under the given title."""
        # Query with a random vector of the index's own dimension so the metadata
        # filter, not similarity, determines which IDs come back.
        stats = self.pinecone_index.describe_index_stats()
        input_vector = np.random.rand(stats["dimension"]).tolist()
        return self.get_ids_from_query(input_vector, title)
    def delete_vector_database(self, title):
        """Delete, in batches, every vector whose `title` metadata matches the given title."""
        try:
            batch_size = 1000
            all_ids = list(self.get_all_ids_from_index(title))
            # Split the IDs into chunks of batch_size to keep each delete request small.
            num_batches = ceil(len(all_ids) / batch_size)
            for i in range(num_batches):
                batch_ids = all_ids[i * batch_size: (i + 1) * batch_size]
                self.pinecone_index.delete(ids=batch_ids)
                logging.info(
                    "Deleted IDs %d to %d successfully", i * batch_size, (i + 1) * batch_size
                )
        except Exception as e:
            logging.error("Error deleting vectors: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="An error occurred while deleting vectors from the index",
            )
    def update_vector_database(self, current_reference, new_reference):
        """Replace the metadata of every vector matching the current reference's title."""
        all_ids = list(self.get_all_ids_from_index(current_reference["title"]))
        for vector_id in all_ids:
            self.pinecone_index.update(
                id=vector_id,
                set_metadata=new_reference,
            )
    def load_existing_indexes(self):
        """Load the existing vector index from Pinecone and return it."""
        try:
            pinecone_index = self.client.Index(self.index_name)
            vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
            index = VectorStoreIndex.from_vector_store(vector_store)
            return index
        except Exception as e:
            logging.error("Error loading existing indexes: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error loading existing indexes: {str(e)}",
            )
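

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module). It assumes
    # PINECONE_CONFIG carries a valid API key and that LlamaIndex's global Settings are
    # configured with an LLM and a 3072-dimensional embedding model matching the index.
    logging.basicConfig(level=logging.INFO)

    manager = IndexManager(index_name="multimodal-index")

    # Load the index already stored in Pinecone and run a query against it.
    index = manager.load_existing_indexes()
    query_engine = index.as_query_engine()
    print(query_engine.query("What is covered in the indexed documents?"))

    # Remove every vector tagged with a hypothetical document title.
    manager.delete_vector_database(title="example-document")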