Spaces:
Runtime error
Runtime error
File size: 5,478 Bytes
73d2546 8e70e09 73d2546 99b6299 8e70e09 99b6299 73d2546 8e70e09 99b6299 8e70e09 73d2546 8e70e09 73d2546 8e70e09 99b6299 8e70e09 99b6299 8e70e09 99b6299 8e70e09 99b6299 8e70e09 73d2546 8e70e09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import json
import os
import uuid

import chromadb
import faiss
import fitz
import gradio as gr
import numpy as np
from datasets import load_dataset, Dataset
from dotenv import load_dotenv
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

from pastebin_api import get_protected_content
load_dotenv()
# Zilliz connection parameters
ZILLIZ_HOST = os.getenv('ZILLIZ_HOST')
ZILLIZ_PORT = os.getenv('ZILLIZ_PORT')
ZILLIZ_USER = os.getenv('ZILLIZ_USER')
ZILLIZ_PASSWORD = os.getenv('ZILLIZ_PASSWORD')
ZILLIZ_COLLECTION = os.getenv('ZILLIZ_COLLECTION')
# Pastebin API parameters
PERSONA_PASTE_KEY = os.getenv('PERSONA_PASTE_KEY')
# Load Llama 3 model components
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
# Initialize ChromaDB client
client = chromadb.Client()
collection = client.create_collection("user-documents")
def extract_text_from_pdf(pdf_files):
texts = []
for pdf in pdf_files:
doc = fitz.open(pdf.name)
text = ""
for page in doc:
text += page.get_text()
texts.append(text)
return texts
def create_dataset_from_pdfs(pdf_files):
    """Build a Hugging Face ``Dataset`` with a single ``text`` column
    holding the extracted text of each uploaded PDF."""
    extracted = extract_text_from_pdf(pdf_files)
    return Dataset.from_dict({"text": extracted})
def create_and_save_faiss_index(dataset, dataset_path, index_path):
    """Embed the dataset's passages, build a FAISS L2 index, and persist both.

    Args:
        dataset: HF ``Dataset`` with a ``"text"`` column of passage strings.
        dataset_path: directory where the dataset is saved to disk.
        index_path: file path where the FAISS index is written.
    """
    passages = dataset["text"]
    encoded = tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
    # BUG FIX: the BatchEncoding must be unpacked into keyword arguments;
    # the original passed it positionally, handing a dict to `input_ids`
    # and crashing at runtime.  Mean-pool the encoder's last hidden state
    # to get one embedding vector per passage.
    passage_embeddings = (
        model.get_encoder()(**encoded)
        .last_hidden_state.mean(dim=1)
        .detach()
        .numpy()
    )
    index = faiss.IndexFlatL2(passage_embeddings.shape[1])
    index.add(passage_embeddings)
    faiss.write_index(index, index_path)
    dataset.save_to_disk(dataset_path)
def load_persona_data():
    """Return the agent persona parsed (as JSON) from the protected
    Pastebin paste identified by ``PERSONA_PASTE_KEY``."""
    raw = get_protected_content(PERSONA_PASTE_KEY)
    return json.loads(raw)
def rag_answer(question, pdf_files, use_user_pdfs=False):
    """Generate a persona-flavoured answer to *question* via RAG.

    When ``use_user_pdfs`` is true, builds a fresh FAISS index over the
    uploaded PDFs and retrieves from it; otherwise connects to the
    pre-existing Zilliz/Milvus collection.  Returns the decoded model
    output string.
    """
    if use_user_pdfs:
        # Build dataset + FAISS index on disk, then load a custom-index
        # retriever from those paths.
        dataset = create_dataset_from_pdfs(pdf_files)
        create_and_save_faiss_index(dataset, "user_dataset_path", "user_index_path")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq",
            index_name="custom",
            passages_path="user_dataset_path",
            index_path="user_index_path"
        )
    else:
        # NOTE(review): a new connection is opened on every call with no
        # disconnect — consider connecting once at module level.
        connections.connect(
            host=ZILLIZ_HOST,
            port=ZILLIZ_PORT,
            user=ZILLIZ_USER,
            password=ZILLIZ_PASSWORD,
            secure=True
        )
        collection = Collection(ZILLIZ_COLLECTION)
        # NOTE(review): RagRetriever's documented constructor is
        # (config, question_encoder_tokenizer, generator_tokenizer, index=...);
        # passing the question encoder, a Milvus Collection, and a length
        # positionally looks incorrect and would likely fail at retrieval
        # time — confirm against the transformers RagRetriever API.
        retriever = RagRetriever(
            model.question_encoder,
            collection,
            model.generator.config.max_combined_length
        )
    # Attach the chosen retriever so model.generate() performs retrieval.
    model.retriever = retriever
    persona_data = load_persona_data()
    # Double braces keep {question} as a literal placeholder in the f-string,
    # filled by .format() below.
    prompt_template = f"{persona_data['persona_text']}\n\nUser: {{question}}\nAssistant:"
    inputs = tokenizer(
        prompt_template.format(question=question),
        return_tensors="pt",
        truncation=True
    )
    outputs = model.generate(**inputs)
    # batch_decode returns a list (batch of 1); return the single answer.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
def add_pdfs_to_chromadb(pdf_files):
    """Index the text of the uploaded PDFs into the module-level ChromaDB
    collection, tagging each document with its source file name.

    BUG FIX: the original used sequential ids ("0", "1", ...) regardless of
    content, so every subsequent upload collided with the first batch
    (ChromaDB rejects or overwrites duplicate ids, depending on version).
    Random UUIDs make each upload's ids unique.
    """
    texts = extract_text_from_pdf(pdf_files)
    collection.add(
        documents=texts,
        metadatas=[{"source": pdf.name} for pdf in pdf_files],
        ids=[str(uuid.uuid4()) for _ in pdf_files],
    )
def query_chromadb(question):
    """Return the two nearest ChromaDB matches for *question*."""
    return collection.query(query_texts=[question], n_results=2)
def create_demo():
    """Build and return the Gradio Blocks UI for the RAG agent.

    BUG FIX: the original instantiated components outside any ``gr.Blocks``
    context and then handed this builder to ``gr.Interface(fn=...)`` with no
    inputs/outputs, which raises at startup.  Components must be created
    inside a ``Blocks`` context, and the resulting demo object is what gets
    launched.

    Returns:
        gr.Blocks: the assembled demo, ready for ``.launch()``.
    """
    with gr.Blocks(title="RAG Conversational Agent") as demo:
        gr.Markdown(
            """
            # RAG-based Conversational Agent with Unique Persona
            This application demonstrates a conversational agent that uses Retrieval Augmented Generation (RAG) to answer questions based on a predefined set of PDF documents. The agent's responses are guided by a persona to maintain a consistent tone and style.
            You can either query the existing knowledge base or upload your own PDFs for the agent to use. The agent will dynamically select the appropriate content source based on your inputs.
            """
        )
        with gr.Row():
            with gr.Column():
                question = gr.Textbox(label="Enter your question")
                user_pdfs = gr.File(label="Upload your PDFs", file_count="multiple", file_types=[".pdf"])
                submit_btn = gr.Button("Submit")
            with gr.Column():
                answer = gr.Textbox(label="Agent's Response", interactive=False)
                source = gr.Textbox(label="Response Source", interactive=False)

        def ask_question(question, user_pdfs):
            # Route to the uploaded PDFs when present, else the Zilliz KB.
            if user_pdfs:
                add_pdfs_to_chromadb(user_pdfs)
                response = rag_answer(question, user_pdfs, use_user_pdfs=True)
                return response, "User Uploaded PDFs"
            response = rag_answer(question, None)
            return response, "Predefined Knowledge Base"

        submit_btn.click(
            ask_question,
            inputs=[question, user_pdfs],
            outputs=[answer, source]
        )
    return demo


if __name__ == "__main__":
    create_demo().launch()