# ragDOcs / app.py
# Author: izhan001 -- commit b3bf1cf (verified): "Create app.py" -- 4.25 kB
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF for PDF files
from docx import Document
from pptx import Presentation
import gradio as gr
# Sentence-embedding model shared by indexing (add_to_index) and querying
# (retrieve_docs).
retrieve = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# In-memory corpus state. `documents` and `doc_embeddings` both start empty;
# the FAISS index stays None until the first document has been added.
documents, doc_embeddings = [], []
index = None
def process_pdf(file_path):
    """Extract the plain text of every page of a PDF file.

    Args:
        file_path: Path of the PDF to read.

    Returns:
        The text of all pages concatenated in page order, or a string of the
        form ``"Error reading PDF: <reason>"`` if anything goes wrong.
    """
    try:
        # The context manager closes the document (and its file handle) even
        # if text extraction fails part-way through.
        with fitz.open(file_path) as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        # Failures are reported through the return value instead of being
        # raised; load_document() inspects the returned string.
        return f"Error reading PDF: {e}"
def process_docx(file_path):
    """Return the text of a DOCX file, one paragraph per line.

    Only the document's ``paragraphs`` collection is read. On any failure a
    string of the form ``"Error reading DOCX: <reason>"`` is returned instead
    of raising.
    """
    try:
        paragraphs = Document(file_path).paragraphs
        return "\n".join(paragraph.text for paragraph in paragraphs)
    except Exception as exc:
        return f"Error reading DOCX: {exc}"
def process_pptx(file_path):
    """Return the text of every text-bearing shape in a PPTX file.

    Slides and shapes are visited in order; each shape exposing a ``text``
    attribute contributes its text followed by a newline. On any failure a
    string of the form ``"Error reading PPTX: <reason>"`` is returned instead
    of raising.
    """
    try:
        deck = Presentation(file_path)
        chunks = []
        for slide in deck.slides:
            for shape in slide.shapes:
                # Shapes without a `text` attribute are skipped.
                if not hasattr(shape, "text"):
                    continue
                chunks.append(shape.text + "\n")
        return "".join(chunks)
    except Exception as exc:
        return f"Error reading PPTX: {exc}"
def add_to_index(text):
    """Embed ``text`` and append it to the global FAISS index.

    Empty or whitespace-only text is ignored. The index is created on the
    first insert, sized to the embedding dimension; later inserts add only
    the new vector instead of rebuilding the index from every stored
    embedding.

    Args:
        text: Full text of the document to index.

    Returns:
        None. Updates the module-level ``index``, ``doc_embeddings`` and
        ``documents``.
    """
    global index
    if not text.strip():
        return

    # Force float32 so the vector can be handed to FAISS regardless of the
    # dtype the encoder produced.
    embedding = np.ascontiguousarray(retrieve.encode([text])[0], dtype=np.float32)

    if index is None:
        index = faiss.IndexFlatL2(embedding.shape[0])

    # Add to the index first: if this raises, the bookkeeping lists are left
    # untouched and positions in `documents` still match index ids.
    index.add(embedding.reshape(1, -1))
    doc_embeddings.append(embedding)
    documents.append(text)
def load_document(file_path):
    """Extract the text of one document and add it to the index.

    The format is chosen from the file extension (case-insensitive): ``.pdf``,
    ``.docx`` or ``.pptx``.

    Args:
        file_path: Path of the file to load.

    Returns:
        ``"Document loaded and indexed successfully."`` on success,
        ``"Unsupported file format"`` for any other extension, or the
        ``"Error reading ...: <reason>"`` message produced by the extractor.
    """
    # Compare on a lowercased copy so "REPORT.PDF" is accepted, but pass the
    # original path to the extractor.
    lowered = file_path.lower()
    if lowered.endswith(".pdf"):
        text = process_pdf(file_path)
    elif lowered.endswith(".docx"):
        text = process_docx(file_path)
    elif lowered.endswith(".pptx"):
        text = process_pptx(file_path)
    else:
        return "Unsupported file format"

    # The extractors signal failure with a fixed message prefix. Matching that
    # prefix (rather than the substring "Error" anywhere) keeps documents that
    # merely mention the word "Error" from being rejected.
    error_prefixes = (
        "Error reading PDF:",
        "Error reading DOCX:",
        "Error reading PPTX:",
    )
    if isinstance(text, str) and not text.startswith(error_prefixes):
        add_to_index(text)
        return "Document loaded and indexed successfully."
    return text  # Extractor error message
def retrieve_docs(query, k=2):
    """Return the indexed documents closest to ``query``.

    Args:
        query: Free-text query to embed and search for.
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``k`` document texts in the order FAISS returned
        them. If nothing has been indexed yet, a one-element list holding an
        explanatory message is returned instead.
    """
    if index is None or not documents:
        return ["Index not initialized. Please upload and process a document first."]

    # Never ask for more neighbours than there are documents: FAISS pads the
    # result with -1, and documents[-1] would silently repeat the last entry.
    k = min(k, len(documents))
    if k <= 0:
        return []

    query_embedding = np.ascontiguousarray(retrieve.encode([query]), dtype=np.float32)
    _distances, indices = index.search(query_embedding, k)
    # Defensive filter in case the index still reports padding/out-of-range ids.
    return [documents[i] for i in indices[0] if 0 <= i < len(documents)]
def generate_response(retrieved_docs):
    """Build the placeholder answer from the retrieved documents.

    The documents are joined with single spaces, cut to the first 500
    characters and followed by ``...``. A fixed fallback message is returned
    when ``retrieved_docs`` is empty or otherwise falsy.
    """
    if not retrieved_docs:
        return "No relevant documents found to generate a response."
    snippet = " ".join(retrieved_docs)[:500]
    # Placeholder response: the retrieved context is echoed back.
    return f"Generated response based on retrieved docs:\n\n{snippet}..."
def rag_application(query, file):
    """Gradio callback: optionally index an upload, then answer ``query``.

    Args:
        query: The user's question.
        file: The uploaded file, or a falsy value when nothing was uploaded.
            Either a path string or an object exposing the path as ``.name``.

    Returns:
        A ``(retrieved_snippets, response)`` tuple of strings. If loading the
        upload fails, the tuple is ``(error_message, "")``.
    """
    if file:
        # Depending on the Gradio version/configuration the File component
        # delivers a plain path string or a temp-file object with `.name`.
        file_path = file if isinstance(file, str) else file.name
        load_result = load_document(file_path)
        if "Error" in load_result:
            return load_result, ""  # Surface the loading error to the user

    retrieved_docs = retrieve_docs(query)
    # Show a 200-character snippet of each retrieved document.
    docs_output = "\n".join(f"- {doc[:200]}..." for doc in retrieved_docs)

    response = generate_response(retrieved_docs)
    return docs_output, response
# Gradio UI: a query box plus a single file upload feed rag_application, whose
# two string results are shown as the retrieved snippets and the response.
iface = gr.Interface(
    fn=rag_application,
    inputs=["text", "file"],  # query, single file upload
    outputs=["text", "text"],  # retrieved documents, generated response
    title="RAG Application with Single File Upload",
    description=(
        "Upload a PDF, DOCX, or PPTX file and ask questions. "
        "The RAG application retrieves relevant documents and generates a response."
    ),
)

# Start the web app as soon as the module is executed.
iface.launch()