"""Search and RAG API.

FastAPI service that embeds a query with ``nomic-ai/nomic-embed-text-v1.5``,
searches the Qdrant collection ``COLLECTION_NAME`` for the caller's documents,
and (for the RAG endpoints) passes the hits to
``services.openai_service.generate_rag_response``.

Authentication:

* ``/api/search-documents`` and ``/api/generate-rag-response`` use the
  ``utils.auth.token_required`` dependency, which yields a
  ``(customer_id, user_id)`` pair taken from a JWT.
* The ``/v1`` variants use ``utils.auth_x.x_api_key_auth`` and take
  ``organization_id`` / ``user_id`` from the request body.

Required environment variables (may also come from a ``.env`` file):
``HUGGINGFACE_HUB_TOKEN``, ``QDRANT_URL``, ``QDRANT_ACCESS_TOKEN``.
"""

import logging
import os
import threading

from dotenv import load_dotenv

# Load environment variables from the .env file *before* importing modules that
# may read them at import time (Hugging Face libraries, project services).
load_dotenv()

# Set the cache directory for Hugging Face. huggingface_hub copies HF_HOME into
# module-level constants when it is first imported, so this must run before the
# imports below; setting it afterwards leaves the hub cache and the token file
# written by login() at their default locations.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]
os.makedirs(hf_home_dir, exist_ok=True)

import torch  # noqa: E402
from fastapi import Depends, FastAPI, HTTPException  # noqa: E402
from huggingface_hub import login  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

from services.openai_service import generate_rag_response  # noqa: E402
from services.qdrant_searcher import QdrantSearcher  # noqa: E402
from utils.auth import token_required  # noqa: E402
from utils.auth_x import x_api_key_auth  # noqa: E402

# Initialize FastAPI application
app = FastAPI()

# Setup logging using Python's standard logging library
logging.basicConfig(level=logging.INFO)

# Qdrant collection that holds the precomputed document embeddings.
COLLECTION_NAME = "embed"
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"

# --- Startup: credentials, model and searcher --------------------------------
# Failures here abort the import of this module, so they raise ordinary
# exceptions; an HTTPException only has meaning while a request is handled.

huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    logging.error(f"Failed to log into Hugging Face Hub: {e}")
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")

qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')
if not qdrant_url or not access_token:
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")

try:
    # NOTE(review): trust_remote_code=True executes Python code fetched from the
    # model repository; consider pinning `revision=` to a reviewed commit.
    tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME, trust_remote_code=True)
    model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME, trust_remote_code=True)
    logging.info("Successfully loaded the model and tokenizer with transformers.")
    # Initialize the Qdrant searcher after the model is successfully loaded
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    logging.error(f"Failed to load the model or initialize searcher: {e}")
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e


# --- Embedding -----------------------------------------------------------------

# Serializes use of the shared tokenizer/model. The endpoints run in FastAPI's
# threadpool, and a Hugging Face "fast" tokenizer instance can fail with
# "RuntimeError: Already borrowed" when called from several threads at once.
_embed_lock = threading.Lock()


def embed_text(text):
    """Embed *text* with the loaded model using attention-masked mean pooling.

    Args:
        text: A query string (a list of strings also works; padded positions
            are excluded from the mean).

    Returns:
        A numpy array with one pooled vector per input row, i.e. shape
        (1, hidden_size) for a single string.

    NOTE(review): the nomic-embed-text-v1.5 model card recommends a
    "search_query: " prefix on queries plus layer-norm/L2 normalisation after
    pooling. Neither is applied here, because query vectors must match however
    the stored vectors were produced — confirm against the indexing pipeline.
    """
    with _embed_lock, torch.no_grad():  # inference only: skip autograd bookkeeping
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        hidden = model(**inputs).last_hidden_state  # (batch, seq_len, hidden_size)
        mask = inputs.get("attention_mask")
        if mask is None:
            embeddings = hidden.mean(dim=1)
        else:
            # Masked mean: identical to a plain mean for an unpadded single
            # string, and keeps padding tokens out of batched embeddings.
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            embeddings = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        return embeddings.cpu().numpy()


# --- Request body models -------------------------------------------------------

class SearchDocumentsRequest(BaseModel):
    query: str
    limit: int = 3


class GenerateRAGRequest(BaseModel):
    search_query: str


class XApiKeyRequest(BaseModel):
    organization_id: str
    user_id: str
    search_query: str


# --- Shared request handling helpers -------------------------------------------

def _require_jwt_user(credentials):
    """Return ``user_id`` from ``token_required`` credentials or raise HTTP 401."""
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    return user_id


def _search_hits(query, user_id, limit=None):
    """Embed *query* and search ``COLLECTION_NAME`` for *user_id*'s documents.

    When *limit* is None the searcher's own default limit applies.
    Raises HTTPException(500) when the searcher reports an error.
    """
    logging.info("Starting document search")
    query_embedding = embed_text(query)
    if limit is None:
        hits, error = searcher.search_documents(COLLECTION_NAME, query_embedding, user_id)
    else:
        hits, error = searcher.search_documents(COLLECTION_NAME, query_embedding, user_id, limit=limit)
    if error:
        logging.error(f"Search documents error: {error}")
        raise HTTPException(status_code=500, detail=error)
    logging.info(f"Document search completed with {len(hits)} hits")
    return hits


def _rag_answer(query, user_id):
    """Search for *query* and return ``{"response": ...}`` generated from the hits.

    Raises HTTPException(500) when the search or the generation reports an error.
    """
    hits = _search_hits(query, user_id)
    logging.info("Generating RAG response")
    response, error = generate_rag_response(hits, query)
    if error:
        logging.error(f"Generate RAG response error: {error}")
        raise HTTPException(status_code=500, detail=error)
    return {"response": response}


def _call_or_500(operation, *args):
    """Return ``operation(*args)``, converting unexpected errors to HTTP 500.

    HTTPExceptions raised by *operation* propagate unchanged, so their status
    and detail reach the client intact; any other exception is logged with its
    traceback and re-raised as HTTPException(500, str(exc)).
    """
    try:
        return operation(*args)
    except HTTPException:
        raise  # already carries the intended status/detail; do not re-wrap
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Endpoints -----------------------------------------------------------------
# The search/RAG endpoints are plain `def` on purpose. Embedding, the Qdrant
# search and the RAG generation are all synchronous calls, and FastAPI runs
# `def` path operations in a worker threadpool rather than on the event loop.
# Declared `async def`, they blocked every other request while they ran.

@app.get("/")
async def root():
    return {"message": "Welcome to the Search and RAG API!, go to relevant address for API request"}


@app.post("/api/search-documents")
def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required),
):
    """Return the top ``body.limit`` documents for the JWT user matching ``body.query``."""
    user_id = _require_jwt_user(credentials)
    logging.info("Received request to search documents")
    logging.debug(f"search query {body.query}")
    return _call_or_500(_search_hits, body.query, user_id, body.limit)


@app.post("/api/generate-rag-response")
def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required),
):
    """Return ``{"response": ...}`` generated from the JWT user's matching documents."""
    user_id = _require_jwt_user(credentials)
    logging.info("Received request to generate RAG response")
    logging.debug(f"search query {body.search_query}")
    return _call_or_500(_rag_answer, body.search_query, user_id)


@app.post("/api/search-documents/v1")
def search_documents_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth),
):
    """Return the top 3 documents for ``body.user_id`` matching ``body.search_query``."""
    # Assuming x_api_key_auth validates the key and returns a truthy value on success
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    organization_id = body.organization_id
    user_id = body.user_id
    # NOTE(review): organization_id is only logged; the search is scoped by the
    # caller-supplied user_id alone. Confirm that any valid API key may query
    # any user's documents.
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    return _call_or_500(_search_hits, body.search_query, user_id, 3)


@app.post("/api/generate-rag-response/v1")
def generate_rag_response_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth),
):
    """Return ``{"response": ...}`` generated from ``body.user_id``'s matching documents."""
    # Assuming x_api_key_auth validates the key and returns a truthy value on success
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    organization_id = body.organization_id
    user_id = body.user_id
    # NOTE(review): as above, the search is scoped only by the body's user_id.
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    return _call_or_500(_rag_answer, body.search_query, user_id)


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)