import configparser
import logging
import os
from typing import Dict, Any, List, Optional

# Local imports
from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article
from App_Function_Libraries.DB.DB_Manager import search_db, fetch_keywords_for_media
from App_Function_Libraries.Utils.Utils import load_comprehensive_config

# 3rd-party imports
import openai

# Load the config file that sits next to this module (Config_Files/config.txt)
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
config = configparser.ConfigParser()
config.read(config_path)

# Read the OpenAI key from config.txt rather than a hard-coded placeholder; the
# provider calls in generate_answer() still pass their own keys explicitly.
openai.api_key = config.get('API', 'openai_api_key', fallback='')
def enhanced_rag_pipeline(query: str, api_choice: str, keywords: Optional[str] = None) -> Dict[str, Any]:
    """Run the full RAG pipeline: retrieve relevant chunks, build a context, and ask the chosen LLM."""
    try:
        # Load the embedding provider from config, falling back to 'openai'
        embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
        logging.debug(f"Using embedding provider: {embedding_provider}")

        # Normalize the comma-separated keyword string, if one was provided
        keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
        logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")

        # Fetch the media IDs tagged with those keywords; None means "no filtering"
        relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None
        logging.debug(f"enhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}")

        # Perform vector search across all Chroma collections
        vector_results = perform_vector_search(query, relevant_media_ids)
        logging.debug(f"enhanced_rag_pipeline - Vector search results: {vector_results}")

        # Perform full-text search against the media database
        fts_results = perform_full_text_search(query, relevant_media_ids)
        logging.debug(f"enhanced_rag_pipeline - Full-text search results: {fts_results}")

        # Combine results from both retrievers
        all_results = vector_results + fts_results

        # Placeholder for a re-ranking step; not implemented yet
        apply_re_ranking = False
        if apply_re_ranking:
            pass

        # Build the prompt context from the top 10 combined results
        context = "\n".join([result['content'] for result in all_results[:10]])
        logging.debug(f"Context length: {len(context)}")
        logging.debug(f"Context: {context[:200]}")

        # Generate an answer using the selected API
        answer = generate_answer(api_choice, context, query)

        if not all_results:
            # Nothing was retrieved, so the answer above was produced from the query alone
            logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
            return {
                "answer": "No relevant information based on your query and keywords was found in the database. Your query has been passed directly to the LLM, and here is its answer: \n\n" + answer,
                "context": "No relevant information based on your query and keywords was found in the database. The only context used was your query: \n\n" + query
            }

        return {
            "answer": answer,
            "context": context
        }
    except Exception as e:
        logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
        return {
            "answer": "An error occurred while processing your request.",
            "context": ""
        }
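# Illustrative usage (a sketch, kept as a comment so nothing runs at import time).
# It assumes the media database and Chroma collections are already populated and
# that config.txt holds a valid key for the chosen provider; the query and
# keywords below are placeholder values only.
#
#   result = enhanced_rag_pipeline(
#       query="What does the stored article say about embeddings?",
#       api_choice="OpenAI",
#       keywords="rag,embeddings",
#   )
#   print(result["answer"])   # LLM answer grounded in the retrieved context
#   print(result["context"])  # the top-10 chunks joined into the prompt context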
def generate_answer(api_choice: str, context: str, query: str) -> str:
    """Send the retrieved context plus the user's question to the selected summarization backend."""
    logging.debug("Entering generate_answer function")
    config = load_comprehensive_config()
    logging.debug(f"Config sections: {config.sections()}")
    prompt = f"Context: {context}\n\nQuestion: {query}"
    # Each provider is imported lazily inside its branch so only the chosen backend's
    # dependencies are loaded; API keys come from the [API] section of the loaded config.
    if api_choice == "OpenAI":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openai
        return summarize_with_openai(config['API']['openai_api_key'], prompt, "")
    elif api_choice == "Anthropic":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_anthropic
        return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")
    elif api_choice == "Cohere":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_cohere
        return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")
    elif api_choice == "Groq":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_groq
        return summarize_with_groq(config['API']['groq_api_key'], prompt, "")
    elif api_choice == "OpenRouter":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openrouter
        return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")
    elif api_choice == "HuggingFace":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_huggingface
        return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")
    elif api_choice == "DeepSeek":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_deepseek
        return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")
    elif api_choice == "Mistral":
        from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_mistral
        return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")
    elif api_choice == "Local-LLM":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_local_llm
        return summarize_with_local_llm(config['API']['local_llm_path'], prompt, "")
    elif api_choice == "Llama.cpp":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_llama
        return summarize_with_llama(config['API']['llama_api_key'], prompt, "")
    elif api_choice == "Kobold":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_kobold
        return summarize_with_kobold(config['API']['kobold_api_key'], prompt, "")
    elif api_choice == "Ooba":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_oobabooga
        return summarize_with_oobabooga(config['API']['ooba_api_key'], prompt, "")
    elif api_choice == "TabbyAPI":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_tabbyapi
        return summarize_with_tabbyapi(config['API']['tabby_api_key'], prompt, "")
    elif api_choice == "vLLM":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_vllm
        return summarize_with_vllm(config['API']['vllm_api_key'], prompt, "")
    elif api_choice == "ollama":
        from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_ollama
        return summarize_with_ollama(config['API']['ollama_api_key'], prompt, "")
    else:
        raise ValueError(f"Unsupported API choice: {api_choice}")
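# A minimal sketch of the config sections this module expects; the key names are
# taken from the lookups above and from enhanced_rag_pipeline(), the values are
# placeholders, and only the providers you actually select need real entries.
#
#   [Embeddings]
#   provider = openai
#
#   [API]
#   openai_api_key = <your key>
#   anthropic_api_key = <your key>
#   cohere_api_key = <your key>
#   groq_api_key = <your key>
#   local_llm_path = <path to local model>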
def perform_vector_search(query: str, relevant_media_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Query every Chroma collection and return the hits, optionally filtered by media ID."""
    all_collections = chroma_client.list_collections()
    vector_results = []
    for collection in all_collections:
        collection_results = vector_search(collection.name, query, k=5)
        # Keep only results whose media_id is in the relevant set (no filtering when None)
        filtered_results = [
            result for result in collection_results
            if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
        ]
        vector_results.extend(filtered_results)
    return vector_results
def perform_full_text_search(query: str, relevant_media_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Run a full-text search against the media database and normalize the results."""
    fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
    # Reshape DB rows into the same {content, metadata} structure the vector search returns
    filtered_fts_results = [
        {
            "content": result['content'],
            "metadata": {"media_id": result['id']}
        }
        for result in fts_results
        if relevant_media_ids is None or result['id'] in relevant_media_ids
    ]
    return filtered_fts_results
def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
    """Collect the set of media IDs tagged with any of the given keywords."""
    relevant_ids = set()
    try:
        for keyword in keywords:
            media_ids = fetch_keywords_for_media(keyword)
            relevant_ids.update(media_ids)
    except Exception as e:
        logging.error(f"Error fetching relevant media IDs: {str(e)}")
    return list(relevant_ids)
def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str]) -> List[Dict[str, Any]]:
    """Keep only results whose media is tagged with at least one of the given keywords."""
    if not keywords:
        return results

    filtered_results = []
    for result in results:
        try:
            metadata = result.get('metadata', {})
            if metadata is None:
                logging.warning(f"No metadata found for result: {result}")
                continue
            if not isinstance(metadata, dict):
                logging.warning(f"Unexpected metadata type: {type(metadata)}. Expected dict.")
                continue

            media_id = metadata.get('media_id')
            if media_id is None:
                logging.warning(f"No media_id found in metadata: {metadata}")
                continue

            # Case-insensitive match between the media's keywords and the requested ones
            media_keywords = fetch_keywords_for_media(media_id)
            if any(keyword.lower() in [mk.lower() for mk in media_keywords] for keyword in keywords):
                filtered_results.append(result)
        except Exception as e:
            logging.error(f"Error processing result: {result}. Error: {str(e)}")

    return filtered_results
def extract_media_id_from_result(result: str) -> Optional[int]:
    """Extract the numeric media ID from a result identifier whose leading segment is the ID."""
    try:
        return int(result.split('_')[0])
    except (IndexError, ValueError):
        logging.error(f"Failed to extract media_id from result: {result}")
        return None
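# Minimal smoke test: a sketch assuming Config_Files/config.txt exists, the media
# database and Chroma store are populated, and "OpenAI" is a configured provider.
# The sample query is illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    demo = enhanced_rag_pipeline("What topics are covered in the stored media?", "OpenAI")
    print(demo["answer"])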