import os

import numpy as np
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from utils.logger import setup_logger
from utils.model_loader import ModelLoader

logger = setup_logger(__name__)


class RAGSystem:
    """Minimal retrieve-then-read RAG pipeline over a CSV of product titles.

    Retrieval: dense embeddings (all-MiniLM-L6-v2) ranked by cosine similarity.
    Reading:   extractive QA (DistilBERT fine-tuned on SQuAD) over the
               concatenated retrieved titles.
    """

    def __init__(self, csv_path="apparel.csv"):
        """Load both models and index the CSV at *csv_path*.

        Args:
            csv_path: path to a CSV with a 'Title' column.

        Raises:
            FileNotFoundError: if *csv_path* does not exist.
            Exception: re-raised after logging if model or data setup fails.
        """
        try:
            # Sentence encoder used for dense retrieval.
            self.embedder = SentenceTransformer(
                'sentence-transformers/all-MiniLM-L6-v2'
            )
            # Extractive QA reader.
            self.qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased-distilled-squad",
            )
            self.setup_system(csv_path)
        except Exception as e:
            logger.error(f"Failed to initialize RAGSystem: {str(e)}")
            raise

    def setup_system(self, csv_path):
        """Read the CSV and precompute one embedding per document title.

        Raises:
            FileNotFoundError: if *csv_path* does not exist.
            ValueError: if the CSV lacks a 'Title' column or has no rows.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")
        try:
            self.documents = pd.read_csv(csv_path)
            # Fail fast with a clear message instead of a bare KeyError.
            if 'Title' not in self.documents.columns:
                raise ValueError(f"CSV at {csv_path} has no 'Title' column")
            self.texts = self.documents['Title'].astype(str).tolist()
            # An empty corpus would crash later inside cosine_similarity;
            # surface the problem at load time instead.
            if not self.texts:
                raise ValueError(f"CSV at {csv_path} contains no rows")
            # Embeddings are computed once at startup and reused per query.
            self.embeddings = self.embedder.encode(self.texts)
            logger.info(f"Successfully loaded {len(self.texts)} documents")
        except Exception as e:
            logger.error(f"Failed to setup RAG system: {str(e)}")
            raise

    def get_relevant_documents(self, query, top_k=5):
        """Return up to *top_k* document texts most similar to *query*.

        Best-effort: returns an empty list on any retrieval error.
        """
        try:
            query_embedding = self.embedder.encode([query])
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # Clamp explicitly so top_k larger than the corpus is well-defined.
            k = min(top_k, len(self.texts))
            # Indices of the k highest similarity scores, best first.
            top_indices = np.argsort(similarities)[-k:][::-1]
            return [self.texts[i] for i in top_indices]
        except Exception as e:
            logger.error(f"Error retrieving relevant documents: {str(e)}")
            return []

    def process_query(self, query):
        """Answer *query* by retrieving context and running extractive QA.

        Returns the extracted answer string, or a human-readable error
        message (never raises).
        """
        try:
            relevant_docs = self.get_relevant_documents(query)
            if not relevant_docs:
                return "No relevant documents found."
            context = " ".join(relevant_docs)
            # BUG FIX: the original sliced context[:512] *characters* —
            # roughly 100 tokens, far below the model's 512-token window,
            # and liable to cut mid-word. The HF question-answering
            # pipeline already splits overlong contexts into overlapping
            # spans (doc_stride), so pass the full context.
            qa_input = {
                "question": query,
                "context": context,
            }
            answer = self.qa_pipeline(qa_input)
            return answer['answer']
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return f"Failed to process query: {str(e)}"