# Hugging Face Spaces page status at capture time: Sleeping
from huggingface_hub import login | |
from fastapi import FastAPI, Depends, HTTPException | |
import logging | |
from pydantic import BaseModel | |
from sentence_transformers import SentenceTransformer | |
from services.qdrant_searcher import QdrantSearcher | |
from services.openai_service import generate_rag_response | |
from utils.auth import token_required | |
from dotenv import load_dotenv | |
import os | |
# Populate the process environment from a local .env file so the os.getenv
# lookups further down can see the configured tokens and URLs.
load_dotenv()

# The FastAPI application object served by uvicorn at the bottom of the file.
app = FastAPI()
# Point the Hugging Face libraries at a cache directory under /tmp.
# NOTE(review): huggingface_hub and sentence_transformers are already imported
# at the top of this file; if they resolve HF_HOME at import time this
# assignment comes too late to affect them -- confirm, and consider exporting
# HF_HOME before those imports run.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Create the cache directory up front. exist_ok=True makes this a no-op when
# the directory already exists and avoids the check-then-create race of a
# separate os.path.exists() test.
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)
# Setup logging
logging.basicConfig(level=logging.INFO)

# Authenticate against the Hugging Face Hub with the token taken from the
# environment. Both failure modes below abort startup.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    # This code runs at import time, outside any request, so an HTTPException
    # could never be converted into an HTTP response. Raise a plain startup
    # error instead and keep the original cause chained for debugging.
    logging.exception("Failed to log into Hugging Face Hub: %s", e)
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
else:
    logging.info("Successfully logged into Hugging Face Hub.")
# Read the Qdrant connection settings from the environment; both are mandatory.
qdrant_url = os.environ.get('QDRANT_URL')
access_token = os.environ.get('QDRANT_ACCESS_TOKEN')
if not (qdrant_url and access_token):
    # Rejects unset variables as well as ones set to an empty string.
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
# Load the embedding model that is handed to the searcher below.
# cache_folder is passed explicitly so the weights land in cache_dir even if
# the HF_HOME assignment earlier in this file was made after the Hugging Face
# libraries had already been imported.
# NOTE(review): the model card for this checkpoint may require
# trust_remote_code=True -- confirm against the installed
# sentence-transformers version before enabling it.
try:
    encoder = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', cache_folder=cache_dir)
except Exception as e:
    # Import-time failure: there is no request to answer, so raise a plain
    # startup error (with the cause chained) rather than an HTTPException.
    logging.exception("Failed to load the SentenceTransformer model: %s", e)
    raise RuntimeError("Failed to load the SentenceTransformer model.") from e
else:
    logging.info("Successfully loaded the SentenceTransformer model.")

# Build the searcher used by both endpoints, wiring in the encoder and the
# Qdrant connection settings read above.
searcher = QdrantSearcher(encoder, qdrant_url, access_token)
# Request body accepted by the document-search endpoint.
class SearchDocumentsRequest(BaseModel):
    # Query text handed to searcher.search_documents.
    query: str
    # Forwarded to searcher.search_documents as its limit argument;
    # defaults to 3 when the client omits it.
    limit: int = 3
# Request body accepted by the RAG-generation endpoint.
class GenerateRAGRequest(BaseModel):
    # Used both to retrieve documents and as the query passed to
    # generate_rag_response.
    search_query: str
# Define the search documents endpoint
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# handler, so as shown it is never registered on `app` -- confirm the intended
# path and restore the decorator.
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Search the "documents" collection on behalf of the authenticated user.

    Args:
        body: Query text and the limit forwarded to the searcher.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        The hits returned by ``searcher.search_documents``.

    Raises:
        HTTPException: 401 when either id is missing from the credentials;
            500 when the searcher reports an error or fails unexpectedly.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")
    try:
        hits, error = searcher.search_documents("documents", body.query, user_id, body.limit)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return hits
    except HTTPException:
        # Already a deliberate HTTP error with the right detail; without this
        # clause the generic handler below would re-wrap it and mangle the
        # detail string.
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Define the generate RAG response endpoint
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# handler, so as shown it is never registered on `app` -- confirm the intended
# path and restore the decorator.
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Retrieve documents for the query and generate a RAG answer from them.

    Args:
        body: Carries ``search_query``, used for both retrieval and
            generation.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 401 when either id is missing from the credentials;
            500 when retrieval or generation reports an error or fails
            unexpectedly.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")
    try:
        # Retrieval step: uses the searcher's default limit.
        hits, error = searcher.search_documents("documents", body.search_query, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        # Generation step: answer the query from the retrieved hits.
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return {"response": response}
    except HTTPException:
        # Already a deliberate HTTP error with the right detail; without this
        # clause the generic handler below would re-wrap it and mangle the
        # detail string.
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# When this file is run as a script, serve `app` with uvicorn on all
# interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)