# Captured from a Hugging Face Space whose status was "Runtime error".
import json
import os
from functools import lru_cache

import chromadb
import faiss
import fitz
import gradio as gr
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value, load_dataset
from dotenv import load_dotenv
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)

from pastebin_api import get_protected_content
# Pull any variables defined in a .env file into the process environment so
# the os.environ lookups below can see them.
load_dotenv()

# Zilliz / Milvus connection settings; each is None when the variable is unset.
ZILLIZ_HOST = os.environ.get('ZILLIZ_HOST')
ZILLIZ_PORT = os.environ.get('ZILLIZ_PORT')
ZILLIZ_USER = os.environ.get('ZILLIZ_USER')
ZILLIZ_PASSWORD = os.environ.get('ZILLIZ_PASSWORD')
ZILLIZ_COLLECTION = os.environ.get('ZILLIZ_COLLECTION')

# Key of the Pastebin paste that holds the persona definition (None if unset).
PERSONA_PASTE_KEY = os.environ.get('PERSONA_PASTE_KEY')
# Pretrained RAG-Sequence tokenizer and model from the `facebook/rag-sequence-nq`
# checkpoint, loaded once at import time and shared by every request.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

# ChromaDB client (default settings) and the collection for uploaded documents.
# get_or_create_collection is used instead of create_collection so that running
# this module again where the collection already exists (e.g. a reload) does
# not raise.
client = chromadb.Client()
collection = client.get_or_create_collection("user-documents")
def extract_text_from_pdf(pdf_files):
    """Extract the plain text of each PDF, one string per file.

    Args:
        pdf_files: Iterable of uploaded PDFs. Each item may be a path string /
            ``os.PathLike``, or a file-like object exposing its path as
            ``.name`` (e.g. older Gradio ``gr.File`` values). ``None`` is
            treated as "no files".

    Returns:
        list[str]: For each input file, in order, the text of all its pages
        concatenated.
    """
    texts = []
    for pdf in pdf_files or []:
        path = os.fspath(pdf) if isinstance(pdf, (str, os.PathLike)) else pdf.name
        # The context manager closes the document (and its file handle) even
        # if extraction fails part-way; the previous code never closed it.
        with fitz.open(path) as doc:
            texts.append("".join(page.get_text() for page in doc))
    return texts
def create_dataset_from_pdfs(pdf_files):
    """Build a ``datasets.Dataset`` with one row per PDF.

    The dataset has a single ``"text"`` column holding each file's extracted
    text, in the same order as *pdf_files*.
    """
    return Dataset.from_dict({"text": extract_text_from_pdf(pdf_files)})
# Passage (context) encoder used to embed uploaded documents. This is the
# default context encoder paired with facebook/rag-sequence-nq in the
# transformers RAG "use your own knowledge dataset" example.
_CTX_ENCODER_NAME = "facebook/dpr-ctx_encoder-multiset-base"


@lru_cache(maxsize=1)
def _load_ctx_encoder():
    """Load the DPR context encoder and its tokenizer once per process."""
    encoder = DPRContextEncoder.from_pretrained(_CTX_ENCODER_NAME)
    encoder.eval()
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(_CTX_ENCODER_NAME)
    return encoder, ctx_tokenizer


def create_and_save_faiss_index(dataset, dataset_path, index_path, batch_size=16):
    """Embed the dataset's passages and save them in the layout RagRetriever loads.

    ``RagRetriever.from_pretrained(..., index_name="custom")`` expects
    ``passages_path`` to be a saved dataset with ``title``, ``text`` and
    ``embeddings`` columns, and ``index_path`` to be a FAISS index over those
    embeddings. Passages are embedded with a DPR *context* encoder so they
    live in the same space as the RAG question encoder's queries. RAG scores
    passages by dot product, so the index is an inner-product index.

    Args:
        dataset: ``datasets.Dataset`` with a ``"text"`` column and, optionally,
            a ``"title"`` column (missing titles are stored as ``""``).
        dataset_path: Directory the passages dataset is saved to.
        index_path: File the FAISS index is written to.
        batch_size: Number of passages encoded per forward pass.

    Raises:
        ValueError: If ``dataset`` contains no passages.
    """
    passages = list(dataset["text"])
    if not passages:
        raise ValueError("Cannot build a FAISS index from a dataset with no passages.")
    if "title" in dataset.column_names:
        titles = list(dataset["title"])
    else:
        titles = [""] * len(passages)

    encoder, ctx_tokenizer = _load_ctx_encoder()
    chunks = []
    # Inference only: skip autograd bookkeeping and encode in bounded batches.
    with torch.no_grad():
        for start in range(0, len(passages), batch_size):
            batch = ctx_tokenizer(
                passages[start:start + batch_size],
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )
            output = encoder(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
            )
            chunks.append(output.pooler_output.cpu().numpy())
    # FAISS needs a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(np.concatenate(chunks), dtype=np.float32)

    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, index_path)

    features = Features({
        "title": Value("string"),
        "text": Value("string"),
        "embeddings": Sequence(Value("float32")),
    })
    passages_dataset = Dataset.from_dict(
        {"title": titles, "text": passages, "embeddings": embeddings.tolist()},
        features=features,
    )
    passages_dataset.save_to_disk(dataset_path)
def load_persona_data():
    """Download the persona definition and decode it from JSON.

    The raw text is fetched with ``get_protected_content`` using the
    ``PERSONA_PASTE_KEY`` setting; ``rag_answer`` reads the decoded value's
    ``"persona_text"`` entry.

    Returns:
        The decoded JSON value.
    """
    raw_json = get_protected_content(PERSONA_PASTE_KEY)
    return json.loads(raw_json)
def _build_user_pdf_retriever(pdf_files):
    """Index the uploaded PDFs on disk and return a retriever over them.

    The passages and FAISS index are written to the fixed relative paths
    ``user_dataset_path`` / ``user_index_path`` on every call.

    Returns:
        tuple: ``(retriever, passage_count)``, where ``passage_count`` is the
        number of rows in the indexed dataset.
    """
    dataset = create_dataset_from_pdfs(pdf_files)
    create_and_save_faiss_index(dataset, "user_dataset_path", "user_index_path")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path="user_dataset_path",
        index_path="user_index_path",
    )
    return retriever, len(dataset)


def _build_knowledge_base_retriever():
    """Connect to the Zilliz knowledge base; retrieval over it is not implemented.

    ``RagRetriever`` cannot wrap a pymilvus ``Collection`` directly. Its
    constructor takes ``(config, question_encoder_tokenizer,
    generator_tokenizer, index=None, ...)``. The collection would have to be
    exposed through a custom transformers RAG ``Index`` implementation, whose
    field mapping depends on the collection schema, and that schema is not
    defined in this module.

    Raises:
        NotImplementedError: Always, after connecting and opening the
            collection.
    """
    connections.connect(
        host=ZILLIZ_HOST,
        port=ZILLIZ_PORT,
        user=ZILLIZ_USER,
        password=ZILLIZ_PASSWORD,
        secure=True,
    )
    Collection(ZILLIZ_COLLECTION)
    raise NotImplementedError(
        "Answering from the Zilliz knowledge base is not implemented: "
        "RagRetriever cannot wrap a pymilvus Collection directly. Provide a "
        "custom transformers RAG Index backed by the collection, or upload PDFs."
    )


def rag_answer(question, pdf_files, use_user_pdfs=False):
    """Answer *question* with the RAG model, framed by the configured persona.

    Args:
        question: The user's question.
        pdf_files: Uploaded PDFs; only used when ``use_user_pdfs`` is true.
        use_user_pdfs: Retrieve from the uploaded PDFs instead of the
            predefined Zilliz knowledge base.

    Returns:
        str: The first decoded generated sequence.

    Raises:
        NotImplementedError: When ``use_user_pdfs`` is false (knowledge-base
            retrieval is not implemented).
    """
    if use_user_pdfs:
        retriever, passage_count = _build_user_pdf_retriever(pdf_files)
    else:
        retriever, passage_count = _build_knowledge_base_retriever()

    # `model.retriever` is a read-only property on RagSequenceForGeneration;
    # set_retriever() is the supported way to attach one. This mutates the
    # shared module-level model.
    model.set_retriever(retriever)

    persona_data = load_persona_data()
    # Interpolate directly: a second str.format() pass would break on any
    # literal braces in the persona text.
    prompt = f"{persona_data['persona_text']}\n\nUser: {question}\nAssistant:"
    # NOTE(review): this whole prompt is also the retrieval query, and with the
    # tokenizer's usual right-side truncation a long persona could cut off the
    # user's question. Consider retrieving on the question alone.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(
        # Forward only ids/mask: the question-encoder tokenizer also emits
        # token_type_ids, which generate() would pass on to the BART generator.
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Don't request more documents than were indexed; FAISS returns id -1
        # for neighbours it cannot fill.
        n_docs=min(model.config.n_docs, passage_count),
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
def add_pdfs_to_chromadb(pdf_files):
    """Store the full text of each uploaded PDF in the ChromaDB collection.

    Args:
        pdf_files: Uploaded PDFs, as path strings / ``os.PathLike`` objects
            or objects exposing their path as ``.name``. ``None`` or an empty
            sequence is a no-op.

    Each document id is a SHA-256 digest of its source path and text, so ids
    differ between uploads. Sequential ``"0", "1", ...`` ids would be reused
    on every call, and later uploads would collide with earlier ones.
    """
    import hashlib

    if not pdf_files:
        return
    texts = extract_text_from_pdf(pdf_files)
    sources = [
        os.fspath(pdf) if isinstance(pdf, (str, os.PathLike)) else pdf.name
        for pdf in pdf_files
    ]
    # Keyed by id so an identical (source, text) pair appearing twice in one
    # upload is sent once; dict order preserves the order of pdf_files.
    entries = {}
    for source, text in zip(sources, texts):
        doc_id = hashlib.sha256(f"{source}\0{text}".encode("utf-8")).hexdigest()
        entries.setdefault(doc_id, (source, text))
    collection.add(
        documents=[text for _, text in entries.values()],
        metadatas=[{"source": source} for source, _ in entries.values()],
        ids=list(entries),
    )
def query_chromadb(question):
    """Look up the 2 stored documents most similar to *question*.

    Returns:
        The raw result of the module-level ChromaDB ``collection.query`` call.
    """
    top_k = 2
    return collection.query(n_results=top_k, query_texts=[question])
def create_demo():
    """Build and return the Gradio ``Blocks`` app for the RAG agent.

    Gradio components only render when they are created inside a layout
    context, so everything is built under ``gr.Blocks``. The caller launches
    the returned app.

    Returns:
        gr.Blocks: The assembled, not-yet-launched demo.
    """
    with gr.Blocks(title="RAG Conversational Agent") as demo:
        gr.Markdown(
            """
            # RAG-based Conversational Agent with Unique Persona
            This application demonstrates a conversational agent that uses Retrieval Augmented Generation (RAG) to answer questions based on a predefined set of PDF documents. The agent's responses are guided by a persona to maintain a consistent tone and style.

            You can either query the existing knowledge base or upload your own PDFs for the agent to use. The agent will dynamically select the appropriate content source based on your inputs.
            """
        )
        gr.Markdown("Ask questions to the agent based on predefined or user-uploaded PDF content.")
        with gr.Row():
            with gr.Column():
                question = gr.Textbox(label="Enter your question")
                user_pdfs = gr.File(label="Upload your PDFs", file_count="multiple", file_types=[".pdf"])
                submit_btn = gr.Button("Submit")
            with gr.Column():
                answer = gr.Textbox(label="Agent's Response", interactive=False)
                source = gr.Textbox(label="Response Source", interactive=False)

        def ask_question(question, user_pdfs):
            """Answer *question*, using the uploaded PDFs when any are given.

            Returns:
                tuple[str, str]: The agent's answer and a label naming the
                content source that was used.
            """
            if user_pdfs:
                add_pdfs_to_chromadb(user_pdfs)
                response = rag_answer(question, user_pdfs, use_user_pdfs=True)
                return response, "User Uploaded PDFs"
            response = rag_answer(question, None)
            return response, "Predefined Knowledge Base"

        submit_btn.click(
            ask_question,
            inputs=[question, user_pdfs],
            outputs=[answer, source],
        )
    return demo


if __name__ == "__main__":
    # gr.Interface requires explicit inputs/outputs and cannot wrap a
    # layout-building function; launch the Blocks app directly instead.
    create_demo().launch()