# NOTE: page residue from the Hugging Face Spaces viewer ("Spaces: Sleeping") — not part of the program.
import gradio as gr | |
import os | |
import lancedb | |
from sentence_transformers import SentenceTransformer | |
from dotenv import load_dotenv | |
from typing import List | |
from PIL import Image | |
import base64 | |
import io | |
import time | |
from collections import namedtuple | |
import pandas as pd | |
import concurrent.futures | |
from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG | |
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor | |
from varag.chunking import FixedTokenChunker | |
from varag.utils import get_model_colpali | |
import argparse | |
import spaces | |
import torch | |
from docling.document_converter import DocumentConverter | |
# Module-level setup: everything below runs at import time and binds the
# shared objects that the handler functions in this file use as globals.
# NOTE(review): presumably load_dotenv() pulls settings from a local .env
# file into the environment — confirm against python-dotenv.
load_dotenv()
# Initialize shared database
# One LanceDB connection is shared by all four RAG instances; each instance
# is given its own table name below.
shared_db = lancedb.connect("~/rag_demo_db")
# Initialize embedding models
# Alternative text embedding models that were tried are kept as comments.
# text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
text_embedding_model = SentenceTransformer(
    "BAAI/bge-base-en", trust_remote_code=True
)
# text_embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", trust_remote_code=True)
# text_embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5", trust_remote_code=True)
# Image embedding model, passed to both VisionRAG and HybridColpaliRAG below.
image_embedding_model = SentenceTransformer(
    "jinaai/jina-clip-v1", trust_remote_code=True
)
# The ColPali model/processor pair is shared by ColpaliRAG and HybridColpaliRAG.
colpali_model, colpali_processor = get_model_colpali("vidore/colpali-v1.2")
# NOTE(review): `converter` is not referenced anywhere else in the visible
# part of this file — confirm whether it is still needed.
converter = DocumentConverter()
# Initialize RAG instances
# The table names here match the default values shown in the Settings tab.
simple_rag = SimpleRAG(
    text_embedding_model=text_embedding_model, db=shared_db, table_name="simpleDemo"
)
vision_rag = VisionRAG(
    image_embedding_model=image_embedding_model, db=shared_db, table_name="visionDemo"
)
colpali_rag = ColpaliRAG(
    colpali_model=colpali_model,
    colpali_processor=colpali_processor,
    db=shared_db,
    table_name="colpaliDemo",
)
hybrid_rag = HybridColpaliRAG(
    colpali_model=colpali_model,
    colpali_processor=colpali_processor,
    image_embedding_model=image_embedding_model,
    db=shared_db,
    table_name="hybridDemo",
)
# Value yielded by ingest_data: the field order (status text, then timing
# table) matches the order of the outputs wired to the ingest button.
IngestResult = namedtuple("IngestResult", ["status_text", "progress_table"])
# @spaces.GPU(duration=120) | |
# def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()): | |
# file_paths = [pdf_file.name for pdf_file in pdf_files] | |
# total_start_time = time.time() | |
# progress_data = [] | |
# # SimpleRAG | |
# yield IngestResult( | |
# status_text="Starting SimpleRAG ingestion...\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# start_time = time.time() | |
# simple_rag.index( | |
# file_paths, | |
# recursive=False, | |
# chunking_strategy=FixedTokenChunker(chunk_size=chunk_size), | |
# metadata={"source": "gradio_upload"}, | |
# overwrite=True, | |
# verbose=True, | |
# ocr=use_ocr, | |
# ) | |
# simple_time = time.time() - start_time | |
# progress_data.append( | |
# {"Technique": "SimpleRAG", "Time Taken (s)": f"{simple_time:.2f}"} | |
# ) | |
# yield IngestResult( | |
# status_text=f"SimpleRAG ingestion complete. Time taken: {simple_time:.2f} seconds\n\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# # progress(0.25, desc="SimpleRAG complete") | |
# # VisionRAG | |
# yield IngestResult( | |
# status_text="Starting VisionRAG ingestion...\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# start_time = time.time() | |
# vision_rag.index(file_paths, overwrite=False, recursive=False, verbose=True) | |
# vision_time = time.time() - start_time | |
# progress_data.append( | |
# {"Technique": "VisionRAG", "Time Taken (s)": f"{vision_time:.2f}"} | |
# ) | |
# yield IngestResult( | |
# status_text=f"VisionRAG ingestion complete. Time taken: {vision_time:.2f} seconds\n\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# # progress(0.5, desc="VisionRAG complete") | |
# # ColpaliRAG | |
# yield IngestResult( | |
# status_text="Starting ColpaliRAG ingestion...\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# start_time = time.time() | |
# colpali_rag.index(file_paths, overwrite=False, recursive=False, verbose=True) | |
# colpali_time = time.time() - start_time | |
# progress_data.append( | |
# {"Technique": "ColpaliRAG", "Time Taken (s)": f"{colpali_time:.2f}"} | |
# ) | |
# yield IngestResult( | |
# status_text=f"ColpaliRAG ingestion complete. Time taken: {colpali_time:.2f} seconds\n\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# # progress(0.75, desc="ColpaliRAG complete") | |
# # HybridColpaliRAG | |
# yield IngestResult( | |
# status_text="Starting HybridColpaliRAG ingestion...\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# start_time = time.time() | |
# hybrid_rag.index(file_paths, overwrite=False, recursive=False, verbose=True) | |
# hybrid_time = time.time() - start_time | |
# progress_data.append( | |
# {"Technique": "HybridColpaliRAG", "Time Taken (s)": f"{hybrid_time:.2f}"} | |
# ) | |
# yield IngestResult( | |
# status_text=f"HybridColpaliRAG ingestion complete. Time taken: {hybrid_time:.2f} seconds\n\n", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
# # progress(1.0, desc="HybridColpaliRAG complete") | |
# total_time = time.time() - total_start_time | |
# progress_data.append({"Technique": "Total", "Time Taken (s)": f"{total_time:.2f}"}) | |
# yield IngestResult( | |
# status_text=f"Total ingestion time: {total_time:.2f} seconds", | |
# progress_table=pd.DataFrame(progress_data), | |
# ) | |
def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
    """Index the uploaded PDFs with every RAG technique, streaming progress.

    Generator used as the click handler for the "Ingest PDFs" button. Each
    yielded ``IngestResult`` carries the status text of the most recent step
    and a DataFrame with the per-technique timings collected so far.

    Args:
        pdf_files: Uploaded files. Items may be objects exposing the path as
            ``.name`` or plain path strings. ``None`` or an empty list yields
            a single "nothing to ingest" status and stops.
        use_ocr: Forwarded to ``simple_rag.index`` as ``ocr``.
        chunk_size: Chunk size given to ``FixedTokenChunker`` for SimpleRAG.
        progress: Gradio progress tracker; kept for interface compatibility,
            not used by this function.

    Yields:
        IngestResult: ``(status_text, progress_table)`` before and after each
        technique, followed by a final entry with the total time.
    """
    if not pdf_files:
        # Clicking the button with no upload passes None; report it instead
        # of failing with a TypeError while iterating.
        yield IngestResult(
            status_text="No PDF files were provided. Upload at least one PDF to ingest.",
            progress_table=pd.DataFrame([]),
        )
        return

    # Accept both file wrappers (path in `.name`) and plain path strings.
    file_paths = [getattr(pdf_file, "name", pdf_file) for pdf_file in pdf_files]
    total_start_time = time.time()
    progress_data = []

    def index_text():
        # NOTE(review): SimpleRAG is indexed with overwrite=True while the
        # image-based pipelines use overwrite=False — confirm this asymmetry
        # is intended.
        simple_rag.index(
            file_paths,
            recursive=False,
            chunking_strategy=FixedTokenChunker(chunk_size=chunk_size),
            metadata={"source": "gradio_upload"},
            overwrite=True,
            verbose=True,
            ocr=use_ocr,
        )

    def index_pages(rag):
        rag.index(file_paths, overwrite=False, recursive=False, verbose=True)

    # (display name, zero-argument indexing call), run in this order.
    stages = (
        ("SimpleRAG", index_text),
        ("VisionRAG", lambda: index_pages(vision_rag)),
        ("ColpaliRAG", lambda: index_pages(colpali_rag)),
        ("HybridColpaliRAG", lambda: index_pages(hybrid_rag)),
    )

    for technique, run_index in stages:
        yield IngestResult(
            status_text=f"Starting {technique} ingestion...\n",
            progress_table=pd.DataFrame(progress_data),
        )
        start_time = time.time()
        run_index()
        elapsed = time.time() - start_time
        progress_data.append(
            {"Technique": technique, "Time Taken (s)": f"{elapsed:.2f}"}
        )
        yield IngestResult(
            status_text=f"{technique} ingestion complete. Time taken: {elapsed:.2f} seconds\n\n",
            progress_table=pd.DataFrame(progress_data),
        )

    total_time = time.time() - total_start_time
    progress_data.append({"Technique": "Total", "Time Taken (s)": f"{total_time:.2f}"})
    yield IngestResult(
        status_text=f"Total ingestion time: {total_time:.2f} seconds",
        progress_table=pd.DataFrame(progress_data),
    )
def retrieve_data(query, top_k, sequential=False):
    """Run the query against all four RAG pipelines and time each one.

    Args:
        query: Query text passed to every pipeline's ``search``.
        top_k: Number of hits requested from each pipeline (passed as ``k``).
        sequential: When true, run the pipelines one after another in the
            calling thread; otherwise submit them to a thread pool.

    Returns:
        tuple: ``(results, timings)``, two dicts keyed by pipeline name in
        the fixed order "SimpleRAG", "VisionRAG", "ColpaliRAG",
        "HybridColpaliRAG". The SimpleRAG result is one formatted text block;
        the others are lists holding the ``"image"`` field of each hit.
        Timings are elapsed seconds per pipeline.
    """
    separator = "-" * 40

    def retrieve_simple():
        start_time = time.time()
        hits = simple_rag.search(query, k=top_k)
        pieces = [
            f"Result {rank}:\n"
            f"Source: {hit.get('document_name', 'Unknown')}\n"
            f"Chunk Index: {hit.get('chunk_index', 'Unknown')}\n"
            f"Content:\n{hit['text']}\n"
            f"{separator}\n"
            for rank, hit in enumerate(hits, 1)
        ]
        simple_context = "\n".join(pieces)
        return "SimpleRAG", simple_context, time.time() - start_time

    def retrieve_images(rag_type, rag, **search_kwargs):
        # Shared by the three image-returning pipelines; only the RAG object
        # and any extra search options differ.
        start_time = time.time()
        hits = rag.search(query, k=top_k, **search_kwargs)
        images = [hit["image"] for hit in hits]
        return rag_type, images, time.time() - start_time

    retrieval_functions = [
        retrieve_simple,
        lambda: retrieve_images("VisionRAG", vision_rag),
        lambda: retrieve_images("ColpaliRAG", colpali_rag),
        lambda: retrieve_images(
            "HybridColpaliRAG", hybrid_rag, use_image_search=True
        ),
    ]

    if sequential:
        outcomes = [func() for func in retrieval_functions]
    else:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(func) for func in retrieval_functions]
            # Read the futures in submission order rather than completion
            # order so the result/timing dicts (and the timing table built
            # from them) list the pipelines in the same order on every run.
            outcomes = [future.result() for future in futures]

    results = {}
    timings = {}
    for rag_type, content, timing in outcomes:
        results[rag_type] = content
        timings[rag_type] = timing
    return results, timings
# @spaces.GPU | |
# def query_data(query, retrieved_results): | |
# results = {} | |
# # SimpleRAG | |
# simple_context = retrieved_results["SimpleRAG"] | |
# simple_response = llm.query( | |
# context=simple_context, | |
# system_prompt="Given the below information answer the questions", | |
# query=query, | |
# ) | |
# results["SimpleRAG"] = {"response": simple_response, "context": simple_context} | |
# # VisionRAG | |
# vision_images = retrieved_results["VisionRAG"] | |
# vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( | |
# [f"Image {i+1}" for i in range(len(vision_images))] | |
# ) | |
# vision_response = vlm.query(vision_context, vision_images, max_tokens=500) | |
# results["VisionRAG"] = { | |
# "response": vision_response, | |
# "context": vision_context, | |
# "images": vision_images, | |
# } | |
# # ColpaliRAG | |
# colpali_images = retrieved_results["ColpaliRAG"] | |
# colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( | |
# [f"Image {i+1}" for i in range(len(colpali_images))] | |
# ) | |
# colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500) | |
# results["ColpaliRAG"] = { | |
# "response": colpali_response, | |
# "context": colpali_context, | |
# "images": colpali_images, | |
# } | |
# # HybridColpaliRAG | |
# hybrid_images = retrieved_results["HybridColpaliRAG"] | |
# hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( | |
# [f"Image {i+1}" for i in range(len(hybrid_images))] | |
# ) | |
# hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500) | |
# results["HybridColpaliRAG"] = { | |
# "response": hybrid_response, | |
# "context": hybrid_context, | |
# "images": hybrid_images, | |
# } | |
# return results | |
def update_api_key(api_key):
    """Store the OpenAI API key in this process's environment.

    Args:
        api_key: Key text from the settings form. Surrounding whitespace
            (e.g. a newline picked up when pasting) is stripped. ``None`` or
            a blank value is rejected and leaves any existing key untouched.

    Returns:
        str: Status message to show in the UI.
    """
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        # os.environ only accepts strings, and storing a blank key would
        # report success while leaving the app without a usable key.
        return "API key not updated: no key was provided."
    # NOTE(review): os.environ is process-wide, so the key applies to every
    # session served by this process — confirm that is acceptable here.
    os.environ["OPENAI_API_KEY"] = cleaned_key
    return "API key updated successfully."
def change_table(simple_table, vision_table, colpali_table, hybrid_table):
    """Switch each RAG pipeline to the table name supplied for it.

    Args:
        simple_table: Table name for ``simple_rag``.
        vision_table: Table name for ``vision_rag``.
        colpali_table: Table name for ``colpali_rag``.
        hybrid_table: Table name for ``hybrid_rag``.

    Returns:
        str: Status message to show in the UI.
    """
    # Applied in the same order as the parameters.
    targets = (
        (simple_rag, simple_table),
        (vision_rag, vision_table),
        (colpali_rag, colpali_table),
        (hybrid_rag, hybrid_table),
    )
    for pipeline, table_name in targets:
        pipeline.change_table(table_name)
    return "Table names updated successfully."
def gradio_interface():
    """Build the Gradio Blocks UI for the Vision RAG playground.

    Lays out three tabs ("Ingest Data", "Retrieve and Query Data",
    "Settings") and wires their buttons to the module-level handlers
    ``ingest_data``, ``retrieve_data``, ``update_api_key`` and
    ``change_table``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks(
        theme=gr.themes.Monochrome(radius_size=gr.themes.sizes.radius_none)
    ) as demo:
        gr.Markdown(
            """
# 👁️👁️ Vision RAG Playground
### Explore and Compare Vision-Augmented Retrieval Techniques
Built on [VARAG](https://github.com/adithya-s-k/VARAG) - Vision-Augmented Retrieval and Generation
**[⭐ Star the Repository](https://github.com/adithya-s-k/VARAG)** to support the project!
1. **Simple RAG**: Text-based retrieval with OCR support for scanned documents.
2. **Vision RAG**: Combines text and image retrieval using cross-modal embeddings.
3. **ColPali RAG**: Embeds entire document pages as images for layout-aware retrieval.
4. **Hybrid ColPali RAG**: Two-stage retrieval combining image embeddings and ColPali's token-level matching.
"""
        )
        with gr.Tab("Ingest Data"):
            gr.Markdown(
                """
## ⚠️ Important Note on Data Ingestion
This Space has a maximum GPU-enabled time of 120 seconds. It's recommended to try ingesting only 1 or 2 pdfs at a time.
If you want to ingest a larger amount of data, please try it out in a Google Colab notebook:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adithya-s-k/VARAG/blob/main/docs/demo.ipynb)
"""
            )
            # File extensions must carry a leading dot; a bare "pdf" is not
            # one of Gradio's named categories and does not match PDF files.
            pdf_input = gr.File(
                label="Upload PDF(s)", file_count="multiple", file_types=[".pdf"]
            )
            use_ocr = gr.Checkbox(label="Use OCR (for SimpleRAG)")
            chunk_size = gr.Slider(
                50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)"
            )
            ingest_button = gr.Button("Ingest PDFs")
            ingest_output = gr.Markdown(
                label="Ingestion Status :",
            )
            progress_table = gr.DataFrame(
                label="Ingestion Progress", headers=["Technique", "Time Taken (s)"]
            )
        with gr.Tab("Retrieve and Query Data"):
            query_input = gr.Textbox(label="Enter your query")
            top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
            sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False)
            retrieve_button = gr.Button("Retrieve")
            # NOTE: the Query button and the *_response panes below have no
            # handler attached; the wiring is commented out further down.
            query_button = gr.Button("Query")
            retrieval_timing = gr.DataFrame(
                label="Retrieval Timings", headers=["RAG Type", "Time (s)"]
            )
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("SimpleRAG", open=True):
                        simple_content = gr.Textbox(
                            label="SimpleRAG Content", lines=10, max_lines=10
                        )
                        simple_response = gr.Markdown(label="SimpleRAG Response")
                with gr.Column():
                    with gr.Accordion("VisionRAG", open=True):
                        vision_gallery = gr.Gallery(label="VisionRAG Images")
                        vision_response = gr.Markdown(label="VisionRAG Response")
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("ColpaliRAG", open=True):
                        colpali_gallery = gr.Gallery(label="ColpaliRAG Images")
                        colpali_response = gr.Markdown(label="ColpaliRAG Response")
                with gr.Column():
                    with gr.Accordion("HybridColpaliRAG", open=True):
                        hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images")
                        hybrid_response = gr.Markdown(label="HybridColpaliRAG Response")
        with gr.Tab("Settings"):
            api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
            update_api_button = gr.Button("Update API Key")
            api_update_status = gr.Textbox(label="API Update Status")
            simple_table_input = gr.Textbox(
                label="SimpleRAG Table Name", value="simpleDemo"
            )
            vision_table_input = gr.Textbox(
                label="VisionRAG Table Name", value="visionDemo"
            )
            colpali_table_input = gr.Textbox(
                label="ColpaliRAG Table Name", value="colpaliDemo"
            )
            hybrid_table_input = gr.Textbox(
                label="HybridColpaliRAG Table Name", value="hybridDemo"
            )
            update_table_button = gr.Button("Update Table Names")
            table_update_status = gr.Textbox(label="Table Update Status")

        # Holds the full results dict of the last retrieval.
        retrieved_results = gr.State({})

        def update_retrieval_results(query, top_k, sequential):
            """Run retrieval and fan the results out to the output widgets."""
            results, timings = retrieve_data(query, top_k, sequential)
            timing_df = pd.DataFrame(
                list(timings.items()), columns=["RAG Type", "Time (s)"]
            )
            return (
                results["SimpleRAG"],
                results["VisionRAG"],
                results["ColpaliRAG"],
                results["HybridColpaliRAG"],
                timing_df,
                results,
            )

        retrieve_button.click(
            update_retrieval_results,
            inputs=[query_input, top_k_slider, sequential_checkbox],
            outputs=[
                simple_content,
                vision_gallery,
                colpali_gallery,
                hybrid_gallery,
                retrieval_timing,
                retrieved_results,
            ],
        )
        # def update_query_results(query, retrieved_results):
        #     results = query_data(query, retrieved_results)
        #     return (
        #         results["SimpleRAG"]["response"],
        #         results["VisionRAG"]["response"],
        #         results["ColpaliRAG"]["response"],
        #         results["HybridColpaliRAG"]["response"],
        #     )
        # query_button.click(
        #     update_query_results,
        #     inputs=[query_input, retrieved_results],
        #     outputs=[
        #         simple_response,
        #         vision_response,
        #         colpali_response,
        #         hybrid_response,
        #     ],
        # )
        ingest_button.click(
            ingest_data,
            inputs=[pdf_input, use_ocr, chunk_size],
            outputs=[ingest_output, progress_table],
        )
        update_api_button.click(
            update_api_key, inputs=[api_key_input], outputs=api_update_status
        )
        update_table_button.click(
            change_table,
            inputs=[
                simple_table_input,
                vision_table_input,
                colpali_table_input,
                hybrid_table_input,
            ],
            outputs=table_update_status,
        )
    return demo
# Parse command-line arguments | |
def parse_args():
    """Parse the command-line options of the demo app.

    Returns:
        argparse.Namespace: Has a boolean ``share`` attribute, true only when
        ``--share`` was passed.
    """
    cli = argparse.ArgumentParser(description="VisionRAG Gradio App")
    cli.add_argument(
        "--share",
        help="Enable Gradio share feature",
        action="store_true",
    )
    options = cli.parse_args()
    return options
# Launch the app | |
if __name__ == "__main__":
    # Parse the CLI first so a bad flag exits before the UI is built.
    cli_options = parse_args()
    gradio_interface().launch(share=cli_options.share)