# pro-search-api / app.py
# (Header copied from the hosted file page — commit bcd2179 by vhr1007:
#  "x_api_key new end point"; page controls: raw | history | blame; size 9.48 kB)
from huggingface_hub import login
from fastapi import FastAPI, Depends, HTTPException
import logging
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from services.qdrant_searcher import QdrantSearcher
from services.openai_service import generate_rag_response
from utils.auth import token_required
from dotenv import load_dotenv
import os
import torch
from utils.auth_x import x_api_key_auth
# Load environment variables from a local .env file (if present) before any
# os.getenv() lookups below.
load_dotenv()

# The FastAPI application instance; the route handlers below register on it.
app = FastAPI()

# Point the Hugging Face cache at a writable temp location.
# NOTE(review): huggingface_hub / transformers are imported at the top of this
# file and likely read HF_HOME at import time, so this assignment may not
# redirect their default cache; pass an explicit cache_dir where it matters.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]

# exist_ok=True avoids the check-then-create race (e.g. several workers
# starting at once) that a separate os.path.exists() test is prone to.
os.makedirs(hf_home_dir, exist_ok=True)
# Configure root logging once for the whole service (INFO and above).
logging.basicConfig(level=logging.INFO)

# Authenticate with the Hugging Face Hub. The token is mandatory: startup
# fails if it is missing or the login call raises.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

try:
    # add_to_git_credential=True also stores the token as a git credential.
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    # This runs at import time, outside any HTTP request, so an HTTPException
    # would never reach a client. Fail startup with a plain error instead and
    # keep the original cause attached for debugging.
    logging.exception(f"Failed to log into Hugging Face Hub: {e}")
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")
# Qdrant connection settings. Both values are required before the searcher
# can be constructed, so startup aborts if either one is missing or empty.
qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')
if not (qdrant_url and access_token):
    raise ValueError(
        "Qdrant URL or Access Token is not set. "
        "Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables."
    )
# Load the nomic embedding tokenizer/model and build the Qdrant searcher.
# trust_remote_code=True is required because this model ships its own
# modelling code on the Hub; that code is downloaded and executed locally.
embedding_model_name = 'nomic-ai/nomic-embed-text-v1.5'
# Explicit cache location under the HF cache dir prepared above. Passing it as
# cache_dir makes the downloads land here regardless of when HF_HOME was read.
cache_folder = os.path.join(hf_home_dir, "transformers_cache")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        embedding_model_name, trust_remote_code=True, cache_dir=cache_folder
    )
    model = AutoModel.from_pretrained(
        embedding_model_name, trust_remote_code=True, cache_dir=cache_folder
    )
    logging.info("Successfully loaded the model and tokenizer with transformers.")
    # Only create the searcher once the model is available; module-level, so
    # the endpoint handlers can use it directly.
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    # Import-time failure: there is no request to attach an HTTP error to, so
    # abort startup with the original cause chained.
    logging.exception(f"Failed to load the model or initialize searcher: {e}")
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e
def embed_text(text):
    """Embed text with the module-level tokenizer/model using mean pooling.

    Args:
        text: A string, or a list of strings, as accepted by the tokenizer.

    Returns:
        A NumPy array holding one pooled embedding row per input text.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # Assumes last_hidden_state is (batch, seq_len, hidden), as the previous
    # mean over dim=1 implied.
    token_embeddings = outputs.last_hidden_state
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        return token_embeddings.mean(dim=1).cpu().numpy()
    # Average over real tokens only. With several texts, padding=True pads the
    # shorter ones, and a plain mean would mix those padding vectors into
    # their embeddings. For a single text the mask is all ones, so this equals
    # a plain mean.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).cpu().numpy()
# Define the request body models
# Body of POST /api/search-documents.
class SearchDocumentsRequest(BaseModel):
    query: str  # free-text query; embedded with embed_text() before searching
    limit: int = 3  # forwarded to searcher.search_documents() as the result limit
# Body of POST /api/generate-rag-response.
class GenerateRAGRequest(BaseModel):
    search_query: str  # used both for retrieval and as the question for the RAG answer
# Body shared by the x-api-key endpoints (/api/search-documents/v1 and
# /api/generate-rag-response/v1). The caller supplies the identities in the
# body because there is no JWT to extract them from.
class XApiKeyRequest(BaseModel):
    organization_id: str  # currently only logged by the handlers
    user_id: str  # passed to searcher.search_documents() to scope the search
    search_query: str  # text to embed and search (and answer, for the RAG endpoint)
@app.get("/")
async def root():
    # Landing route: a static greeting pointing clients at the real endpoints.
    greeting = "Welcome to the Search and RAG API!, go to relevant address for API request"
    payload = {"message": greeting}
    return payload
# Define the search documents endpoint
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    # POST /api/search-documents (JWT auth via token_required).
    #
    # `credentials` is unpacked as (customer_id, user_id). The query is
    # embedded and searched in the "embed" collection, scoped by user_id, with
    # body.limit passed as the result limit. Returns the hits exactly as
    # searcher.search_documents() produced them.
    # Raises HTTPException(401) if either id is missing, and
    # HTTPException(500) on search errors or any unexpected failure.
    #
    # NOTE(review): embed_text() and searcher.search_documents() are called
    # synchronously inside this async handler, so they block the event loop
    # while running. Offloading them (e.g. to a thread pool) would first need
    # the tokenizer/model's thread-safety confirmed.
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.query)
        # Route the query through logging rather than a bare print() to stdout.
        logging.info(f"search query {body.query}")
        collection_name = "embed"  # Qdrant collection holding the precomputed embeddings
        logging.info("Performing search using the precomputed embeddings")
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return hits
    except HTTPException:
        # Already carries the intended status and detail: re-raise it as-is
        # instead of letting the generic handler below wrap it in a new 500
        # whose detail would be str() of this exception.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Define the generate RAG response endpoint
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    # POST /api/generate-rag-response (JWT auth via token_required).
    #
    # `credentials` is unpacked as (customer_id, user_id). The query is
    # embedded and searched in the "embed" collection, scoped by user_id. No
    # limit argument is passed, so the searcher's default applies. The hits
    # and the query then go to generate_rag_response().
    # Returns {"response": <generated answer>}.
    # Raises HTTPException(401) if either id is missing, and
    # HTTPException(500) on search/generation errors or unexpected failures.
    #
    # NOTE(review): the embedding, search and generation calls are synchronous
    # inside this async handler and block the event loop while they run.
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        # Route the query through logging rather than a bare print() to stdout.
        logging.info(f"search query {body.search_query}")
        collection_name = "embed"  # Qdrant collection holding the precomputed embeddings
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")
        # Generate the RAG response using the retrieved documents
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Already carries the intended status and detail: re-raise it as-is
        # instead of letting the generic handler below re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/search-documents/v1")
async def search_documents_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    # POST /api/search-documents/v1 (x-api-key auth via x_api_key_auth).
    #
    # Embeds body.search_query and searches the "embed" collection scoped by
    # the body's user_id, with a fixed limit of 3. Returns the hits exactly as
    # searcher.search_documents() produced them.
    # Raises HTTPException(401) if not authorized, and HTTPException(500) on
    # search errors or unexpected failures.
    #
    # NOTE(review): organization_id is only logged; the search is scoped solely
    # by the client-supplied user_id. Unless x_api_key_auth ties the key to
    # specific users, any key holder could query any user's documents. Please
    # confirm this is intended.
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    organization_id = body.organization_id
    user_id = body.user_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        collection_name = "embed"  # Qdrant collection holding the precomputed embeddings
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Document search completed with {len(hits)} hits")
        return hits
    except HTTPException:
        # Already carries the intended status and detail: re-raise it as-is
        # instead of letting the generic handler below re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/generate-rag-response/v1")
async def generate_rag_response_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    # POST /api/generate-rag-response/v1 (x-api-key auth via x_api_key_auth).
    #
    # Embeds body.search_query and searches the "embed" collection scoped by
    # the body's user_id. No limit argument is passed, so the searcher's
    # default applies. The hits and the query then go to
    # generate_rag_response(). Returns {"response": <generated answer>}.
    # Raises HTTPException(401) if not authorized, and HTTPException(500) on
    # search/generation errors or unexpected failures.
    #
    # NOTE(review): organization_id is only logged, not used to scope the
    # search; access control rests on the client-supplied user_id.
    # Defensive check in case x_api_key_auth returns False instead of raising.
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    organization_id = body.organization_id
    user_id = body.user_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        collection_name = "embed"  # Qdrant collection holding the precomputed embeddings
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")
        # Generate the RAG response using the retrieved documents
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Already carries the intended status and detail: re-raise it as-is
        # instead of letting the generic handler below re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == '__main__':
    # Development entry point: running this module directly serves `app`
    # with uvicorn on all interfaces at port 8000.
    from uvicorn import run as serve
    serve(app, host='0.0.0.0', port=8000)