pro-search-api / app.py
vhr1007's picture
Update app.py
21c27da verified
raw
history blame
4.49 kB
import logging
import os

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import login
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

from services.openai_service import generate_rag_response
from services.qdrant_searcher import QdrantSearcher
from utils.auth import token_required
# Load environment variables from a local .env file (if present) so that the
# os.getenv() lookups below can see them.
load_dotenv()

# Initialize the FastAPI application; the routes are registered further down.
app = FastAPI()

# Point the Hugging Face cache at a writable location.
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported, and
# it has already been imported above (directly and via sentence_transformers).
# Assigning the variable here therefore does not redirect huggingface_hub's
# default cache; code that downloads models should pass this directory
# explicitly (or HF_HOME should be set before those imports).
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
cache_dir = os.environ["HF_HOME"]
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(cache_dir, exist_ok=True)

# Configure the root logger at INFO so the logging.info() calls below are shown.
logging.basicConfig(level=logging.INFO)
# Authenticate against the Hugging Face Hub. The token is mandatory: startup
# aborts when HUGGINGFACE_HUB_TOKEN is unset or empty.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    # This runs at import time, outside any request, so an HTTPException could
    # never reach a client. Fail startup with a plain error instead, and keep
    # the original cause (and its traceback in the log).
    logging.exception(f"Failed to log into Hugging Face Hub: {e}")
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")
# Read the Qdrant connection settings. Both must be non-empty; otherwise
# startup is aborted here, before any model is loaded.
qdrant_url, access_token = (
    os.getenv(setting) for setting in ('QDRANT_URL', 'QDRANT_ACCESS_TOKEN')
)
if not all((qdrant_url, access_token)):
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
# Load the sentence-embedding model that the searcher is built with.
# The download directory is passed explicitly. HF_HOME is assigned only after
# huggingface_hub has been imported, so the hub's default cache location is
# already fixed and ignores it. "<cache_dir>/hub" mirrors the layout the hub
# would use under HF_HOME.
# NOTE(review): the nomic-embed-text-v1.5 model card loads this model with
# trust_remote_code=True because it ships custom modelling code. That flag is
# deliberately not enabled here, since it executes code fetched from the Hub.
# Confirm the model loads without it, or opt in explicitly.
try:
    encoder = SentenceTransformer(
        'nomic-ai/nomic-embed-text-v1.5',
        cache_folder=os.path.join(cache_dir, "hub"),
    )
except Exception as e:
    # Import-time failure: an HTTPException would never reach a client, so
    # abort startup with a plain error and keep the underlying cause.
    logging.exception(f"Failed to load the SentenceTransformer model: {e}")
    raise RuntimeError("Failed to load the SentenceTransformer model.") from e
logging.info("Successfully loaded the SentenceTransformer model.")

# Build the Qdrant searcher from the encoder and the connection settings read
# earlier in this module.
searcher = QdrantSearcher(encoder, qdrant_url, access_token)
# Define the request body models
class SearchDocumentsRequest(BaseModel):
    """JSON body accepted by POST /api/search-documents."""

    # Free-text query passed to the document search.
    query: str
    # Maximum number of hits to return. It must be at least 1, so a zero or
    # negative value is rejected with a 422 before it reaches the searcher.
    limit: int = Field(default=3, ge=1)


class GenerateRAGRequest(BaseModel):
    """JSON body accepted by POST /api/generate-rag-response."""

    # Query text; the endpoint passes it both to the document search and to
    # generate_rag_response.
    search_query: str
# Define the search documents endpoint
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Search the "documents" collection for ``body.query``.

    ``credentials`` is the ``(customer_id, user_id)`` pair returned by the
    ``token_required`` dependency. ``user_id`` and ``body.limit`` are forwarded
    to ``searcher.search_documents``, and its hits are returned unchanged.

    Raises:
        HTTPException: 401 if either id is missing from the token; 500 if the
            searcher reports an error or raises.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to search documents")
    try:
        # searcher.search_documents is a synchronous call (presumably query
        # embedding plus a Qdrant request). Run it in a worker thread so it
        # does not block the event loop for every other in-flight request.
        hits, error = await run_in_threadpool(
            searcher.search_documents, "documents", body.query, user_id, body.limit
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return hits
    except HTTPException:
        # Already carries the intended status and detail. Let it through
        # instead of having the generic handler below log and re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Define the generate RAG response endpoint
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Search the "documents" collection, then generate an answer from the hits.

    ``credentials`` is the ``(customer_id, user_id)`` pair returned by the
    ``token_required`` dependency. ``user_id`` is forwarded to the search.
    Both the search and ``generate_rag_response`` receive
    ``body.search_query``.

    Returns:
        ``{"response": <value returned by generate_rag_response>}``.

    Raises:
        HTTPException: 401 if either id is missing from the token; 500 if the
            search or the generation step reports an error or raises.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to generate RAG response")
    try:
        # Both helpers are synchronous (called without await in this module).
        # Run them in worker threads so the slow search and generation calls
        # do not stall the event loop.
        hits, error = await run_in_threadpool(
            searcher.search_documents, "documents", body.search_query, user_id
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        response, error = await run_in_threadpool(
            generate_rag_response, hits, body.search_query
        )
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Already carries the intended status and detail. Let it through
        # instead of having the generic handler below log and re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# When executed as a script (not imported), serve the app with uvicorn on all
# interfaces, port 8000.
if __name__ == '__main__':
    from uvicorn import run as serve

    serve(app, host='0.0.0.0', port=8000)