oceansweep committed on
Commit
a6ecdfa
1 Parent(s): 32b7e17

Upload 3 files

App_Function_Libraries/RAG/Embeddings_Create.py ADDED
@@ -0,0 +1,167 @@
+ # Embeddings_Create.py
+ # Description: Functions for Creating and managing Embeddings in ChromaDB with Llama.cpp/OpenAI/Transformers
+ #
+ # Imports:
+ import logging
+ from typing import List, Dict, Any
+
+ import numpy as np
+ #
+ # 3rd-Party Imports:
+ import requests
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ #
+ # Local Imports:
+ from App_Function_Libraries.LLM_API_Calls import get_openai_embeddings
+ from App_Function_Libraries.Summarization_General_Lib import summarize
+ from App_Function_Libraries.Utils.Utils import load_comprehensive_config
+ from App_Function_Libraries.Chunk_Lib import chunk_options, improved_chunking_process, determine_chunk_position
+ #
+ #
+ #######################################################################################################################
+ #
+ # Functions:
+
+ # FIXME - Add all globals to summarize.py
+ loaded_config = load_comprehensive_config()
+ embedding_provider = loaded_config['Embeddings']['embedding_provider']
+ embedding_model = loaded_config['Embeddings']['embedding_model']
+ embedding_api_url = loaded_config['Embeddings']['embedding_api_url']
+ embedding_api_key = loaded_config['Embeddings']['embedding_api_key']
+
+ # Embedding Chunking Settings
+ chunk_size = loaded_config['Embeddings']['chunk_size']
+ overlap = loaded_config['Embeddings']['overlap']
+
+
+ # FIXME - Add logging
+
+ # FIXME - refactor/setup to use config file & perform chunking
+ def create_embedding(text: str, provider: str, model: str, api_url: str = None, api_key: str = None) -> List[float]:
+     try:
+         if provider == 'openai':
+             embedding = get_openai_embeddings(text, model)
+         elif provider == 'local':
+             embedding = create_local_embedding(text, model, api_url, api_key)
+         elif provider == 'huggingface':
+             embedding = create_huggingface_embedding(text, model)
+         elif provider == 'llamacpp':
+             embedding = create_llamacpp_embedding(text, api_url)
+         else:
+             raise ValueError(f"Unsupported embedding provider: {provider}")
+
+         if isinstance(embedding, np.ndarray):
+             embedding = embedding.tolist()
+         elif isinstance(embedding, torch.Tensor):
+             embedding = embedding.detach().cpu().numpy().tolist()
+
+         return embedding
+
+     except Exception as e:
+         logging.error(f"Error creating embedding: {str(e)}")
+         raise
+
+
+ def create_huggingface_embedding(text: str, model: str) -> List[float]:
+     tokenizer = AutoTokenizer.from_pretrained(model)
+     model = AutoModel.from_pretrained(model)
+
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     embeddings = outputs.last_hidden_state.mean(dim=1)
+     return embeddings[0].tolist()
+
+
+ # FIXME
+ def create_stella_embeddings(text: str) -> List[float]:
+     if embedding_provider == 'local':
+         # Load the model and tokenizer
+         tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5")
+         model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5")
+
+         # Tokenize and encode the text
+         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+         # Generate embeddings
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Use the mean of the last hidden state as the sentence embedding
+         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+         return embeddings[0].tolist()  # Convert to list for consistency
+     elif embedding_provider == 'openai':
+         return get_openai_embeddings(text, embedding_model)
+     else:
+         raise ValueError(f"Unsupported embedding provider: {embedding_provider}")
+
+
+ def create_llamacpp_embedding(text: str, api_url: str) -> List[float]:
+     response = requests.post(
+         api_url,
+         json={"input": text}
+     )
+     response.raise_for_status()
+     return response.json()['embedding']
+
+
+ def create_local_embedding(text: str, model: str, api_url: str, api_key: str) -> List[float]:
+     response = requests.post(
+         api_url,
+         json={"text": text, "model": model},
+         headers={"Authorization": f"Bearer {api_key}"}
+     )
+     response.raise_for_status()
+     return response.json().get('embedding', None)
+
+
+ def chunk_for_embedding(text: str, file_name: str, api_name, custom_chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]:
+     options = chunk_options.copy()
+     if custom_chunk_options:
+         options.update(custom_chunk_options)
+
+
+     # FIXME
+     if api_name is not None:
+         # Generate summary of the full document
+         full_summary = summarize(text, None, api_name, None, None, None)
+     else:
+         full_summary = "Full document summary not available."
+
+     chunks = improved_chunking_process(text, options)
+     total_chunks = len(chunks)
+
+     chunked_text_with_headers = []
+     for i, chunk in enumerate(chunks, 1):
+         chunk_text = chunk['text']
+         chunk_position = determine_chunk_position(chunk['metadata']['relative_position'])
+
+         chunk_header = f"""
+         Original Document: {file_name}
+         Full Document Summary: {full_summary}
+         Chunk: {i} of {total_chunks}
+         Position: {chunk_position}
+
+         --- Chunk Content ---
+         """
+
+         full_chunk_text = chunk_header + chunk_text
+         chunk['text'] = full_chunk_text
+         chunk['metadata']['file_name'] = file_name
+         chunked_text_with_headers.append(chunk)
+
+     return chunked_text_with_headers
+
+
+ def create_openai_embedding(text: str, model: str) -> List[float]:
+     embedding = get_openai_embeddings(text, model)
+     return embedding
+
+
+
+ #
+ # End of File.
+ #######################################################################################################################
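A minimal usage sketch for the embedding helpers added above, assuming the [Embeddings] section of the config is populated and, for the 'llamacpp' provider, a local llama.cpp server exposing an embeddings endpoint; the URL, file name, and sample text below are placeholders, not values from this commit:

from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, chunk_for_embedding

sample_text = "Long document text to be embedded..."  # placeholder content
# api_name=None skips the full-document summary step in chunk_for_embedding
chunks = chunk_for_embedding(sample_text, "example.txt", None)

for chunk in chunks:
    # 'llamacpp' routes to create_llamacpp_embedding; the endpoint URL is a placeholder
    vector = create_embedding(chunk['text'], 'llamacpp', '', api_url='http://localhost:8080/embedding')
    print(len(vector))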
App_Function_Libraries/RAG/RAG_Libary_2.py ADDED
@@ -0,0 +1,332 @@
+ # RAG_Library_2.py
+ # Description: This script contains the main RAG pipeline function and related functions for the RAG pipeline.
+ #
+ # Import necessary modules and functions
+ import configparser
+ import logging
+ import os
+ from typing import Dict, Any, List, Optional
+ # Local Imports
+ from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
+ from App_Function_Libraries.Article_Extractor_Lib import scrape_article
+ from App_Function_Libraries.DB.DB_Manager import add_media_to_database, search_db, get_unprocessed_media, \
+     fetch_keywords_for_media
+ from App_Function_Libraries.Utils.Utils import load_comprehensive_config
+ #
+ # 3rd-Party Imports
+ import openai
+ #
+ ########################################################################################################################
+ #
+ # Functions:
+
+ # Initialize OpenAI client (adjust this based on your API key management)
+ openai.api_key = "your-openai-api-key"
+
+ # Get the directory of the current script
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ # Construct the path to the config file
+ config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
+ # Create the config parser
+ config = configparser.ConfigParser()
+ # Read the configuration file
+ config.read(config_path)
+
+ # Main RAG pipeline function
+ def rag_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]:
+     try:
+         # Extract content
+         try:
+             article_data = scrape_article(url)
+             content = article_data['content']
+             title = article_data['title']
+         except Exception as e:
+             logging.error(f"Error scraping article: {str(e)}")
+             return {"error": "Failed to scrape article", "details": str(e)}
+
+         # Store the article in the database and get the media_id
+         try:
+             media_id = add_media_to_database(url, title, 'article', content)
+         except Exception as e:
+             logging.error(f"Error adding article to database: {str(e)}")
+             return {"error": "Failed to store article in database", "details": str(e)}
+
+         # Process and store content
+         collection_name = f"article_{media_id}"
+         try:
+             process_and_store_content(content, collection_name, media_id, title)
+         except Exception as e:
+             logging.error(f"Error processing and storing content: {str(e)}")
+             return {"error": "Failed to process and store content", "details": str(e)}
+
+         # Perform searches
+         try:
+             vector_results = vector_search(collection_name, query, k=5)
+             fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
+         except Exception as e:
+             logging.error(f"Error performing searches: {str(e)}")
+             return {"error": "Failed to perform searches", "details": str(e)}
+
+         # Combine results with error handling for missing 'content' key
+         all_results = []
+         for result in vector_results + fts_results:
+             if isinstance(result, dict) and 'content' in result:
+                 all_results.append(result['content'])
+             else:
+                 logging.warning(f"Unexpected result format: {result}")
+                 all_results.append(str(result))
+
+         context = "\n".join(all_results)
+
+         # Generate answer using the selected API
+         try:
+             answer = generate_answer(api_choice, context, query)
+         except Exception as e:
+             logging.error(f"Error generating answer: {str(e)}")
+             return {"error": "Failed to generate answer", "details": str(e)}
+
+         return {
+             "answer": answer,
+             "context": context
+         }
+
+     except Exception as e:
+         logging.error(f"Unexpected error in rag_pipeline: {str(e)}")
+         return {"error": "An unexpected error occurred", "details": str(e)}
+
+
+
+ # RAG Search with keyword filtering
+ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None) -> Dict[str, Any]:
+     try:
+         # Load embedding provider from config, or fall back to 'openai'
+         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
+
+         # Log the provider used
+         logging.debug(f"Using embedding provider: {embedding_provider}")
+
+         # Process keywords if provided
+         keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
+         logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")
+
+         # Fetch relevant media IDs based on keywords if keywords are provided
+         relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None
+         logging.debug(f"enhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}")
+
+         # Perform vector search
+         vector_results = perform_vector_search(query, relevant_media_ids)
+         logging.debug(f"enhanced_rag_pipeline - Vector search results: {vector_results}")
+
+         # Perform full-text search
+         fts_results = perform_full_text_search(query, relevant_media_ids)
+         logging.debug(f"enhanced_rag_pipeline - Full-text search results: {fts_results}")
+
+         # Combine results
+         all_results = vector_results + fts_results
+         # FIXME
+         if not all_results:
+             logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
+             return {
+                 "answer": "I couldn't find any relevant information based on your query and keywords.",
+                 "context": ""
+             }
+
+         # FIXME - Apply Re-Ranking of results here
+         apply_re_ranking = False
+         if apply_re_ranking:
+             # Implement re-ranking logic here
+             pass
+         # Extract content from results
+         context = "\n".join([result['content'] for result in all_results[:10]])  # Limit to top 10 results
+         logging.debug(f"Context length: {len(context)}")
+         logging.debug(f"Context: {context[:200]}")
+         # Generate answer using the selected API
+         answer = generate_answer(api_choice, context, query)
+
+         return {
+             "answer": answer,
+             "context": context
+         }
+     except Exception as e:
+         logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
+         return {
+             "answer": "An error occurred while processing your request.",
+             "context": ""
+         }
+
+
+ def generate_answer(api_choice: str, context: str, query: str) -> str:
+     logging.debug("Entering generate_answer function")
+     config = load_comprehensive_config()
+     logging.debug(f"Config sections: {config.sections()}")
+     prompt = f"Context: {context}\n\nQuestion: {query}"
+     if api_choice == "OpenAI":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openai
+         return summarize_with_openai(config['API']['openai_api_key'], prompt, "")
+     elif api_choice == "Anthropic":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_anthropic
+         return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")
+     elif api_choice == "Cohere":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_cohere
+         return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")
+     elif api_choice == "Groq":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_groq
+         return summarize_with_groq(config['API']['groq_api_key'], prompt, "")
+     elif api_choice == "OpenRouter":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openrouter
+         return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")
+     elif api_choice == "HuggingFace":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_huggingface
+         return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")
+     elif api_choice == "DeepSeek":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_deepseek
+         return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")
+     elif api_choice == "Mistral":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_mistral
+         return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")
+     elif api_choice == "Local-LLM":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_local_llm
+         return summarize_with_local_llm(config['API']['local_llm_path'], prompt, "")
+     elif api_choice == "Llama.cpp":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_llama
+         return summarize_with_llama(config['API']['llama_api_key'], prompt, "")
+     elif api_choice == "Kobold":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_kobold
+         return summarize_with_kobold(config['API']['kobold_api_key'], prompt, "")
+     elif api_choice == "Ooba":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_oobabooga
+         return summarize_with_oobabooga(config['API']['ooba_api_key'], prompt, "")
+     elif api_choice == "TabbyAPI":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_tabbyapi
+         return summarize_with_tabbyapi(config['API']['tabby_api_key'], prompt, "")
+     elif api_choice == "vLLM":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_vllm
+         return summarize_with_vllm(config['API']['vllm_api_key'], prompt, "")
+     elif api_choice == "ollama":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_ollama
+         return summarize_with_ollama(config['API']['ollama_api_key'], prompt, "")
+     else:
+         raise ValueError(f"Unsupported API choice: {api_choice}")
+
+ # Function to preprocess and store all existing content in the database
+ def preprocess_all_content():
+     unprocessed_media = get_unprocessed_media()
+     for row in unprocessed_media:
+         media_id = row[0]
+         content = row[1]
+         media_type = row[2]
+         collection_name = f"{media_type}_{media_id}"
+         process_and_store_content(content, collection_name, media_id, "")
+
+
+ def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+     all_collections = chroma_client.list_collections()
+     vector_results = []
+     for collection in all_collections:
+         collection_results = vector_search(collection.name, query, k=5)
+         filtered_results = [
+             result for result in collection_results
+             if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
+         ]
+         vector_results.extend(filtered_results)
+     return vector_results
+
+
+ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+     fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
+     filtered_fts_results = [
+         {
+             "content": result['content'],
+             "metadata": {"media_id": result['id']}
+         }
+         for result in fts_results
+         if relevant_media_ids is None or result['id'] in relevant_media_ids
+     ]
+     return filtered_fts_results
+
+
+ def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
+     relevant_ids = set()
+     try:
+         for keyword in keywords:
+             media_ids = fetch_keywords_for_media(keyword)
+             relevant_ids.update(media_ids)
+     except Exception as e:
+         logging.error(f"Error fetching relevant media IDs: {str(e)}")
+     return list(relevant_ids)
+
+
+ def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str]) -> List[Dict[str, Any]]:
+     if not keywords:
+         return results
+
+     filtered_results = []
+     for result in results:
+         try:
+             metadata = result.get('metadata', {})
+             if metadata is None:
+                 logging.warning(f"No metadata found for result: {result}")
+                 continue
+             if not isinstance(metadata, dict):
+                 logging.warning(f"Unexpected metadata type: {type(metadata)}. Expected dict.")
+                 continue
+
+             media_id = metadata.get('media_id')
+             if media_id is None:
+                 logging.warning(f"No media_id found in metadata: {metadata}")
+                 continue
+
+             media_keywords = fetch_keywords_for_media(media_id)
+             if any(keyword.lower() in [mk.lower() for mk in media_keywords] for keyword in keywords):
+                 filtered_results.append(result)
+         except Exception as e:
+             logging.error(f"Error processing result: {result}. Error: {str(e)}")
+
+     return filtered_results
+
+ # FIXME: to be implemented
+ def extract_media_id_from_result(result: str) -> Optional[int]:
+     # Implement this function based on how you store the media_id in your results
+     # For example, if it's stored at the beginning of each result:
+     try:
+         return int(result.split('_')[0])
+     except (IndexError, ValueError):
+         logging.error(f"Failed to extract media_id from result: {result}")
+         return None
+
+
+
+
+ # Example usage:
+ # 1. Initialize the system:
+ # create_tables(db)  # Ensure FTS tables are set up
+ #
+ # 2. Create ChromaDB
+ # chroma_client = ChromaDBClient()
+ #
+ # 3. Create Embeddings
+ # Store embeddings in ChromaDB
+ # preprocess_all_content() or create_embeddings()
+ #
+ # 4. Perform RAG search across all content:
+ # result = rag_search("What are the key points about climate change?")
+ # print(result['answer'])
+ #
+ # (Extra)5. Perform RAG on a specific URL:
+ # result = rag_pipeline("https://example.com/article", "What is the main topic of this article?")
+ # print(result['answer'])
+ #
+ ########################################################################################################################
+
+
+ ############################################################################################################
+ #
+ # ElasticSearch Retriever
+
+ # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-elasticsearch
+ #
+ # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-self-query
+
+ #
+ # End of RAG_Library_2.py
+ ############################################################################################################
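A short sketch of how the two pipeline entry points added above might be called once the config file, media database, and ChromaDB collections are in place; the URL, query, keywords, and API choice are placeholders:

from App_Function_Libraries.RAG.RAG_Libary_2 import rag_pipeline, enhanced_rag_pipeline

# Scrape, store, index, and answer a question about a single article
result = rag_pipeline("https://example.com/article", "What is the main topic of this article?", api_choice="OpenAI")
print(result.get("answer", result))

# Keyword-filtered hybrid search (vector + full-text) over already-ingested content
filtered = enhanced_rag_pipeline("key points about climate change", "OpenAI", keywords="climate,policy")
print(filtered["answer"])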
App_Function_Libraries/RAG/RAG_QA_Chat.py ADDED
@@ -0,0 +1,84 @@
+ # RAG_QA_Chat.py
+ # Description: Functions for RAG-based Q&A chat over content stored in the database
+ #
+ # Imports
+ #
+ #
+ # External Imports
+ import json
+ import logging
+ import tempfile
+ from typing import List, Tuple, IO, Union
+ #
+ # Local Imports
+ from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content
+ from App_Function_Libraries.RAG.RAG_Libary_2 import generate_answer
+ #
+ ########################################################################################################################
+ #
+ # Functions:
+
+ def rag_qa_chat(message: str, history: List[Tuple[str, str]], context: Union[str, IO[str]], api_choice: str) -> Tuple[List[Tuple[str, str]], str]:
+     try:
+         # Prepare the context based on the selected source
+         if hasattr(context, 'read'):
+             # Handle uploaded file
+             context_text = context.read()
+             if isinstance(context_text, bytes):
+                 context_text = context_text.decode('utf-8')
+         elif isinstance(context, str) and context.startswith("media_id:"):
+             # Handle existing file or search result
+             media_id = int(context.split(":")[1])
+             context_text = get_media_content(media_id)  # Fetch the content from the database
+         else:
+             context_text = str(context)
+
+         # Prepare the full context including chat history
+         full_context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
+         full_context += f"\n\nContext: {context_text}\n\nHuman: {message}\nAI:"
+
+         # Generate response using the selected API
+         response = generate_answer(api_choice, full_context, message)
+
+         # Update history
+         history.append((message, response))
+
+         return history, ""
+     except DatabaseError as e:
+         logging.error(f"Database error in rag_qa_chat: {str(e)}")
+         return history, f"An error occurred while accessing the database: {str(e)}"
+     except Exception as e:
+         logging.error(f"Unexpected error in rag_qa_chat: {str(e)}")
+         return history, f"An unexpected error occurred: {str(e)}"
+
+
+
+ def save_chat_history(history: List[Tuple[str, str]]) -> str:
+     # Save chat history to a file
+     with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as temp_file:
+         json.dump(history, temp_file)
+         return temp_file.name
+
+
+ def load_chat_history(file: IO[str]) -> List[Tuple[str, str]]:
+     # Load chat history from a file
+     return json.load(file)
+
+
+ def search_database(query: str) -> List[Tuple[int, str]]:
+     # Implement database search functionality
+     results = search_db(query, ["title", "content"], "", page=1, results_per_page=10)
+     return [(result['id'], result['title']) for result in results]
+
+
+ def get_existing_files() -> List[Tuple[int, str]]:
+     # Fetch list of existing files from the database
+     with db.get_connection() as conn:
+         cursor = conn.cursor()
+         cursor.execute("SELECT id, title FROM Media ORDER BY title")
+         return cursor.fetchall()
+
+
+ #
+ # End of RAG_QA_Chat.py
+ ########################################################################################################################
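A sketch of driving the chat helper added above outside the Gradio UI, assuming content has already been ingested; the message, context string, and API choice are placeholders:

from App_Function_Libraries.RAG.RAG_QA_Chat import rag_qa_chat, save_chat_history

history = []
# context may be a plain string, an uploaded file object, or a "media_id:<id>" reference
history, status = rag_qa_chat("Summarize the context.", history, "Some previously stored article text...", "OpenAI")
print(history[-1][1], status)

# Persist the conversation to a temporary JSON file
saved_path = save_chat_history(history)
print(saved_path)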