File size: 5,478 Bytes
73d2546
8e70e09
 
 
 
 
 
73d2546
99b6299
8e70e09
 
99b6299
 
73d2546
8e70e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99b6299
8e70e09
 
 
 
 
 
 
 
 
 
73d2546
 
8e70e09
 
 
 
73d2546
8e70e09
99b6299
8e70e09
 
 
 
99b6299
8e70e09
99b6299
8e70e09
 
 
 
 
 
 
 
 
99b6299
8e70e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73d2546
 
8e70e09
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import fitz
from datasets import load_dataset, Dataset
import faiss
import numpy as np
import chromadb
import os
from dotenv import load_dotenv
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
from pastebin_api import get_protected_content

# Populate os.environ from a local .env file (if any) before the
# configuration values below are read.
load_dotenv()

# Zilliz connection parameters (each is None when the variable is unset)
ZILLIZ_HOST = os.getenv('ZILLIZ_HOST')
ZILLIZ_PORT = os.getenv('ZILLIZ_PORT')
ZILLIZ_USER = os.getenv('ZILLIZ_USER')
ZILLIZ_PASSWORD = os.getenv('ZILLIZ_PASSWORD')
ZILLIZ_COLLECTION = os.getenv('ZILLIZ_COLLECTION')

# Pastebin API parameters
PERSONA_PASTE_KEY = os.getenv('PERSONA_PASTE_KEY')

# Load the pretrained RAG model components (facebook/rag-sequence-nq).
# Both objects are module-level singletons shared by every request.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

# Initialize ChromaDB client
client = chromadb.Client()
# get_or_create_collection keeps start-up idempotent: create_collection
# raises when a collection with this name already exists (for example when
# the module is re-executed in the same process).
collection = client.get_or_create_collection("user-documents")

def extract_text_from_pdf(pdf_files):
    """Extract the plain text of each PDF, one string per input file.

    Args:
        pdf_files: Iterable of uploaded files. Each item is either a path
            (``str`` / ``os.PathLike``) or an object exposing the file path
            as ``.name``. ``None`` is treated as "no files".

    Returns:
        A list of strings in input order; each string is the concatenated
        text of every page of the corresponding PDF.
    """
    texts = []
    for pdf in pdf_files or ():
        # Path objects also have a ``.name`` attribute (the basename only),
        # so paths must be recognised before falling back to ``.name``.
        if isinstance(pdf, (str, os.PathLike)):
            path = os.fspath(pdf)
        else:
            path = pdf.name
        # The context manager closes the document even if a page fails.
        with fitz.open(path) as doc:
            texts.append("".join(page.get_text() for page in doc))
    return texts

def create_dataset_from_pdfs(pdf_files):
    """Wrap the text extracted from *pdf_files* in a ``Dataset``.

    The resulting dataset has a single ``"text"`` column holding one entry
    per input PDF, as produced by ``extract_text_from_pdf``.
    """
    return Dataset.from_dict({"text": extract_text_from_pdf(pdf_files)})

def create_and_save_faiss_index(dataset, dataset_path, index_path):
    """Embed the dataset's passages, index them with FAISS, and save both.

    Args:
        dataset: Dataset with a ``"text"`` column of passages.
        dataset_path: Destination passed to ``dataset.save_to_disk``.
        index_path: Destination file passed to ``faiss.write_index``.

    Raises:
        ValueError: If the dataset contains no passages.
    """
    passages = list(dataset["text"])
    if not passages:
        raise ValueError("cannot build a FAISS index from an empty dataset")

    encoded = tokenizer(
        passages, return_tensors="pt", padding=True, truncation=True
    )
    # Hand the tensors over by keyword: the tokenizer returns a mapping, and
    # passing that mapping positionally would bind the whole mapping to the
    # encoder's first parameter instead of its input ids.
    # NOTE(review): which encoder get_encoder() returns for this RAG model,
    # and whether it matches this tokenizer's vocabulary, is not visible
    # here - confirm.
    hidden_states = model.get_encoder()(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    ).last_hidden_state

    # Mean-pool over the token axis; FAISS expects float32, C-contiguous data.
    passage_embeddings = np.ascontiguousarray(
        hidden_states.mean(dim=1).detach().numpy(), dtype=np.float32
    )

    index = faiss.IndexFlatL2(passage_embeddings.shape[1])
    index.add(passage_embeddings)

    faiss.write_index(index, index_path)
    dataset.save_to_disk(dataset_path)

def load_persona_data():
    """Fetch the persona paste and decode it as JSON.

    The raw content is retrieved with ``get_protected_content`` using
    ``PERSONA_PASTE_KEY``.

    Returns:
        The decoded JSON value. ``rag_answer`` reads its ``"persona_text"``
        entry, so a JSON object is expected.

    Raises:
        json.JSONDecodeError: If the fetched content is not valid JSON.
    """
    # json is not imported at module level, so import it here to keep this
    # function self-contained.
    import json

    persona_content = get_protected_content(PERSONA_PASTE_KEY)
    return json.loads(persona_content)

def rag_answer(question, pdf_files, use_user_pdfs=False):
    """Generate a persona-styled answer to *question* with the RAG model.

    Args:
        question: The user's question text.
        pdf_files: Uploaded PDFs; only read when ``use_user_pdfs`` is true.
        use_user_pdfs: When true, build a FAISS index from ``pdf_files`` and
            retrieve from it; otherwise connect to the Zilliz collection
            named by ``ZILLIZ_COLLECTION``.

    Returns:
        The first decoded sequence produced by ``model.generate``.

    Raises:
        ValueError: If ``use_user_pdfs`` is true but no PDFs were supplied.
    """
    if use_user_pdfs:
        if not pdf_files:
            raise ValueError("use_user_pdfs=True requires at least one PDF file")
        dataset = create_dataset_from_pdfs(pdf_files)
        create_and_save_faiss_index(dataset, "user_dataset_path", "user_index_path")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq",
            index_name="custom",
            passages_path="user_dataset_path",
            index_path="user_index_path"
        )
    else:
        connections.connect(
            host=ZILLIZ_HOST,
            port=ZILLIZ_PORT,
            user=ZILLIZ_USER,
            password=ZILLIZ_PASSWORD,
            secure=True
        )
        zilliz_collection = Collection(ZILLIZ_COLLECTION)
        # NOTE(review): these positional arguments do not look like the
        # documented RagRetriever(config, question_encoder_tokenizer,
        # generator_tokenizer, index=None) constructor - confirm how the
        # Milvus collection is meant to back the retriever.
        retriever = RagRetriever(
            model.question_encoder,
            zilliz_collection,
            model.generator.config.max_combined_length
        )

    # set_retriever() is the documented way to attach a retriever to the
    # model, rather than assigning to an attribute directly.
    model.set_retriever(retriever)

    persona_data = load_persona_data()
    # Interpolate in one step: running str.format over text that already
    # contains the persona would fail on any literal '{' or '}' in it.
    prompt = f"{persona_data['persona_text']}\n\nUser: {question}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Pass only the tensors generate() takes by name; the tokenizer output
    # may carry additional keys (e.g. token_type_ids) - TODO confirm.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

def add_pdfs_to_chromadb(pdf_files):
    """Store the text of each uploaded PDF in the ChromaDB collection.

    Args:
        pdf_files: Uploaded files; each item is a path or an object exposing
            its path as ``.name``. Nothing is added when this is empty or
            ``None``.
    """
    import uuid

    if not pdf_files:
        return

    def source_of(pdf):
        # Path objects also carry ``.name`` (basename only), so check for
        # path-like values before falling back to ``.name``.
        if isinstance(pdf, (str, os.PathLike)):
            return os.fspath(pdf)
        return pdf.name

    texts = extract_text_from_pdf(pdf_files)
    collection.add(
        documents=texts,
        metadatas=[{"source": source_of(pdf)} for pdf in pdf_files],
        # Positional ids ("0", "1", ...) restart on every call and therefore
        # collide with documents stored by earlier uploads; random UUIDs
        # keep every stored document distinct.
        ids=[uuid.uuid4().hex for _ in texts],
    )

def query_chromadb(question):
    """Query the module-level ChromaDB collection with *question*.

    Returns whatever ``collection.query`` yields when asked for two results
    for the single query text.
    """
    query_args = {"query_texts": [question], "n_results": 2}
    return collection.query(**query_args)

def create_demo():
    """Build the Gradio UI for the conversational agent.

    All components are created inside a ``gr.Blocks`` context so that the
    layout and the button's click handler belong to one app.

    Returns:
        The assembled ``gr.Blocks`` app, ready for ``.launch()``.
    """
    with gr.Blocks(title="RAG Conversational Agent") as demo:
        gr.Markdown(
            """
            # RAG-based Conversational Agent with Unique Persona

            This application demonstrates a conversational agent that uses Retrieval Augmented Generation (RAG) to answer questions based on a predefined set of PDF documents. The agent's responses are guided by a persona to maintain a consistent tone and style.

            You can either query the existing knowledge base or upload your own PDFs for the agent to use. The agent will dynamically select the appropriate content source based on your inputs.
            """
        )
        gr.Markdown(
            "Ask questions to the agent based on predefined or user-uploaded PDF content."
        )

        with gr.Row():
            with gr.Column():
                question = gr.Textbox(label="Enter your question")
                user_pdfs = gr.File(label="Upload your PDFs", file_count="multiple", file_types=[".pdf"])
                submit_btn = gr.Button("Submit")

            with gr.Column():
                answer = gr.Textbox(label="Agent's Response", interactive=False)
                source = gr.Textbox(label="Response Source", interactive=False)

        def ask_question(question, user_pdfs):
            """Answer from the uploaded PDFs when given, else the knowledge base."""
            if user_pdfs:
                add_pdfs_to_chromadb(user_pdfs)
                response = rag_answer(question, user_pdfs, use_user_pdfs=True)
                return response, "User Uploaded PDFs"
            response = rag_answer(question, None)
            return response, "Predefined Knowledge Base"

        submit_btn.click(
            ask_question,
            inputs=[question, user_pdfs],
            outputs=[answer, source]
        )

    return demo


if __name__ == "__main__":
    # create_demo() already returns a complete app; wrapping it in
    # gr.Interface is not possible because it is a UI builder, not a
    # prediction function with inputs and outputs.
    demo = create_demo()
    demo.launch()