# ILYA_docs_RAG / app.py
# TheDavidYoungblood
# Resolved merge conflicts
# 99defcd
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import fitz
from datasets import load_dataset, Dataset
import faiss
import numpy as np
import chromadb
import os
from dotenv import load_dotenv
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
from pastebin_api import get_protected_content
# Populate os.environ from a local .env file (when one exists) before any
# settings below are read.
load_dotenv()

# Zilliz / Milvus connection settings. Each is None when the variable is unset.
ZILLIZ_HOST = os.environ.get('ZILLIZ_HOST')
ZILLIZ_PORT = os.environ.get('ZILLIZ_PORT')
ZILLIZ_USER = os.environ.get('ZILLIZ_USER')
ZILLIZ_PASSWORD = os.environ.get('ZILLIZ_PASSWORD')
ZILLIZ_COLLECTION = os.environ.get('ZILLIZ_COLLECTION')

# Key of the protected Pastebin paste whose content is parsed as the persona JSON.
PERSONA_PASTE_KEY = os.environ.get('PERSONA_PASTE_KEY')
# Load the pretrained RAG-Sequence tokenizer and model (facebook/rag-sequence-nq).
# This is a Hugging Face RAG checkpoint, not a Llama model.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

# ChromaDB client (default settings) that stores the user-uploaded documents.
client = chromadb.Client()
# get_or_create_collection rather than create_collection: if this setup runs
# again against a client that already holds the collection (e.g. a module
# reload in the same process), reuse it instead of raising "already exists".
collection = client.get_or_create_collection("user-documents")
def extract_text_from_pdf(pdf_files):
    """Extract the plain text of each PDF, one string per file.

    Args:
        pdf_files: Iterable of uploaded PDFs. Each item may be a filesystem
            path (``str`` / ``os.PathLike``) or an object that exposes its path
            as ``.name`` (such as the temp-file wrappers Gradio's ``gr.File``
            returns). ``None`` or an empty iterable means "no files".

    Returns:
        list[str]: The concatenated text of all pages of each PDF, in input
        order.
    """
    if not pdf_files:
        return []

    texts = []
    for pdf in pdf_files:
        # A path-like item *is* the path (note pathlib.Path.name would only be
        # the basename); anything else carries the path in ``.name``.
        path = os.fspath(pdf) if isinstance(pdf, (str, os.PathLike)) else pdf.name
        # The context manager closes the document and its file handle promptly.
        with fitz.open(path) as doc:
            # join avoids quadratic string concatenation on long documents.
            texts.append("".join(page.get_text() for page in doc))
    return texts
def create_dataset_from_pdfs(pdf_files):
    """Wrap the extracted text of *pdf_files* in a single-column ``Dataset``.

    Returns:
        datasets.Dataset: One row per PDF, with that PDF's full text in the
        ``"text"`` column.
    """
    pdf_texts = extract_text_from_pdf(pdf_files)
    return Dataset.from_dict({"text": pdf_texts})
def create_and_save_faiss_index(dataset, dataset_path, index_path):
    """Embed every passage of *dataset*, then save a FAISS index and the dataset.

    The output is laid out for ``RagRetriever.from_pretrained(...,
    index_name="custom", passages_path=dataset_path, index_path=index_path)``.
    The saved dataset carries ``title``, ``text`` and ``embeddings`` columns,
    and the index is built over exactly those embeddings, row for row.

    Args:
        dataset: A ``datasets.Dataset`` with a ``text`` column, one passage
            per row (in this app, one whole PDF per row).
        dataset_path: Directory the augmented dataset is saved to.
        index_path: File the FAISS index is written to.

    Raises:
        ValueError: If *dataset* contains no passages.
    """
    passages = dataset["text"]
    if len(passages) == 0:
        raise ValueError("Cannot build a FAISS index from an empty dataset.")

    # RagTokenizer.__call__ tokenizes with the question-encoder tokenizer, so the
    # passages are embedded with that same question encoder. Its pooled output
    # (element 0) has the same width, and lives in the same vector space, as the
    # query vectors the retriever searches with. The tokenizer output is passed
    # as named tensors, not as one positional value.
    # NOTE(review): truncation keeps only the first max-length tokens of each
    # PDF; chunking documents into passages would improve retrieval a lot.
    encoded = tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
    encoder_output = model.question_encoder(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    )
    # FAISS expects a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(
        encoder_output[0].detach().cpu().numpy(), dtype=np.float32
    )

    # The RAG retriever reads "title", "text" and "embeddings" from every
    # retrieved row, so make sure all three columns exist.
    if "title" not in dataset.column_names:
        dataset = dataset.add_column("title", [""] * len(passages))
    if "embeddings" in dataset.column_names:
        dataset = dataset.remove_columns("embeddings")
    dataset = dataset.add_column("embeddings", embeddings.tolist())

    # RAG scores retrieved documents by inner product with the question vector,
    # so rank candidates the same way (exact brute-force search).
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, index_path)
    dataset.save_to_disk(dataset_path)
def load_persona_data():
    """Fetch the persona definition from the protected paste and parse it as JSON.

    Returns:
        The JSON-decoded paste content. ``rag_answer`` reads its
        ``"persona_text"`` key, so a JSON object (dict) is expected.
    """
    # Local import: json is not among this module's top-level imports, so the
    # bare name used to raise NameError here.
    import json

    persona_content = get_protected_content(PERSONA_PASTE_KEY)
    return json.loads(persona_content)
def _build_user_pdf_retriever(pdf_files):
    """Index the uploaded PDFs on disk and return a custom-index RagRetriever over them."""
    # NOTE(review): fixed on-disk paths, so concurrent requests overwrite each
    # other's dataset/index files.
    dataset = create_dataset_from_pdfs(pdf_files)
    create_and_save_faiss_index(dataset, "user_dataset_path", "user_index_path")
    return RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path="user_dataset_path",
        index_path="user_index_path",
    )


def _build_zilliz_retriever():
    """Connect to Zilliz and build a retriever over the configured collection."""
    connections.connect(
        host=ZILLIZ_HOST,
        port=ZILLIZ_PORT,
        user=ZILLIZ_USER,
        password=ZILLIZ_PASSWORD,
        secure=True,
    )
    milvus_collection = Collection(ZILLIZ_COLLECTION)
    # NOTE(review): RagRetriever's constructor is
    #   RagRetriever(config, question_encoder_tokenizer, generator_tokenizer,
    #                index=None, init_retrieval=True)
    # so these positional arguments do not match it, and a pymilvus Collection
    # is not a transformers RAG `Index`. This path needs an Index adapter over
    # the collection (get_top_docs / get_doc_dicts), whose field names depend on
    # the collection schema (not visible here), before it can work.
    return RagRetriever(
        model.question_encoder,
        milvus_collection,
        model.generator.config.max_combined_length,
    )


def rag_answer(question, pdf_files, use_user_pdfs=False):
    """Answer *question* with the RAG model, prefixed by the configured persona.

    Args:
        question: The user's question.
        pdf_files: Uploaded PDFs to retrieve from; only used when
            *use_user_pdfs* is true.
        use_user_pdfs: If true, index *pdf_files* and retrieve from them;
            otherwise retrieve from the Zilliz collection.

    Returns:
        str: The first decoded generated answer.
    """
    if use_user_pdfs:
        retriever = _build_user_pdf_retriever(pdf_files)
    else:
        retriever = _build_zilliz_retriever()
    # `retriever` is a read-only property on RagSequenceForGeneration; the
    # supported way to swap it is set_retriever().
    # NOTE(review): this mutates the shared global model, so concurrent
    # requests can race on which retriever is active.
    model.set_retriever(retriever)

    persona_data = load_persona_data()
    # Build the prompt in one step. Applying str.format to a template that
    # already contains the persona text would treat any "{" / "}" in the
    # persona as placeholders and fail.
    prompt = f"{persona_data['persona_text']}\n\nUser: {question}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Pass only input_ids / attention_mask, as in the transformers RAG examples;
    # the question-encoder tokenizer can also return token_type_ids, which are
    # not a RAG generate() argument.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
def add_pdfs_to_chromadb(pdf_files):
    """Store the text of each uploaded PDF as a document in the ChromaDB collection.

    Each PDF becomes one document; its metadata records the file path under
    ``"source"``. Nothing is written when *pdf_files* is empty or None.

    Args:
        pdf_files: Uploaded files: paths, or objects exposing the path as
            ``.name`` (as Gradio's ``gr.File`` returns).
    """
    import uuid  # local import: not among this module's top-level imports

    if not pdf_files:
        # Nothing to store; Chroma may also reject add() with empty id lists.
        return

    pdf_files = list(pdf_files)  # iterated twice below (text and metadata)
    texts = extract_text_from_pdf(pdf_files)
    sources = [
        os.fspath(pdf) if isinstance(pdf, (str, os.PathLike)) else pdf.name
        for pdf in pdf_files
    ]
    collection.add(
        documents=texts,
        metadatas=[{"source": source} for source in sources],
        # Random ids: index-based ids ("0", "1", ...) restart at "0" on every
        # upload and would collide with documents stored by earlier calls.
        ids=[str(uuid.uuid4()) for _ in texts],
    )
def query_chromadb(question, n_results=2):
    """Return the ChromaDB documents most similar to *question*.

    Args:
        question: Query text; Chroma embeds it with the collection's
            embedding function.
        n_results: Maximum number of matches to return (default 2).

    Returns:
        The raw result object from ``Collection.query`` for this single query.
    """
    return collection.query(query_texts=[question], n_results=n_results)
def create_demo():
    """Lay out the Q&A interface in the currently open ``gr.Blocks`` context.

    Adds an introduction, an input column (question, multi-PDF upload, submit
    button) and an output column (answer and its source), then wires the
    button to ``ask_question``. It creates no Blocks of its own and returns
    None, so it is meant to be called inside ``with gr.Blocks(): ...``.
    """

    def ask_question(question, user_pdfs):
        # No uploads: answer from the predefined (Zilliz) knowledge base.
        if not user_pdfs:
            return rag_answer(question, None), "Predefined Knowledge Base"
        # Uploads present: also store them in ChromaDB, then answer from them.
        add_pdfs_to_chromadb(user_pdfs)
        answer_text = rag_answer(question, user_pdfs, use_user_pdfs=True)
        return answer_text, "User Uploaded PDFs"

    intro_markdown = """
# RAG-based Conversational Agent with Unique Persona
This application demonstrates a conversational agent that uses Retrieval Augmented Generation (RAG) to answer questions based on a predefined set of PDF documents. The agent's responses are guided by a persona to maintain a consistent tone and style.
You can either query the existing knowledge base or upload your own PDFs for the agent to use. The agent will dynamically select the appropriate content source based on your inputs.
"""
    gr.Markdown(intro_markdown)

    with gr.Row():
        with gr.Column():
            question_box = gr.Textbox(label="Enter your question")
            pdf_upload = gr.File(
                label="Upload your PDFs", file_count="multiple", file_types=[".pdf"]
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            answer_box = gr.Textbox(label="Agent's Response", interactive=False)
            source_box = gr.Textbox(label="Response Source", interactive=False)

    submit_button.click(
        ask_question,
        inputs=[question_box, pdf_upload],
        outputs=[answer_box, source_box],
    )
if __name__ == "__main__":
    # gr.Interface needs explicit `inputs`/`outputs` and calls `fn` once per
    # request, so it cannot host a layout builder. create_demo() adds its
    # components (and the click listener) to whichever Blocks context is open,
    # so open one, build the UI inside it, and launch that.
    with gr.Blocks(title="RAG Conversational Agent") as demo:
        gr.Markdown(
            "Ask questions to the agent based on predefined or user-uploaded PDF content."
        )
        create_demo()
    demo.launch()