import tqdm
from PIL import Image
import hashlib
import torch
import fitz
import gradio as gr
import os
from transformers import AutoModel, AutoTokenizer
import numpy as np
import json
import spaces

cache_dir = 'kb_cache'
os.makedirs(cache_dir, exist_ok=True)
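
# The functions below reference global `model` and `tokenizer`, but the loading code
# is missing from this snippet. A minimal sketch follows; the checkpoint id and the
# device handling are assumptions, not confirmed by the source.
model_path = 'RhapsodyAI/minicpm-visual-embedding-v0'  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.to('cuda')  # assumes a GPU runtime is available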

def get_image_md5(img: Image.Image):
    """Return the MD5 hex digest of a PIL image's raw bytes (used as a page cache key)."""
    img_byte_array = img.tobytes()
    hash_md5 = hashlib.md5()
    hash_md5.update(img_byte_array)
    hex_digest = hash_md5.hexdigest()
    return hex_digest


def calculate_md5_from_binary(binary_data):
    """Return the MD5 hex digest of raw bytes (used to name the PDF's knowledge-base folder)."""
    hash_md5 = hashlib.md5()
    hash_md5.update(binary_data)
    return hash_md5.hexdigest()
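
# `import spaces` suggests this Space runs on ZeroGPU; decorating the GPU-bound
# functions with @spaces.GPU is an assumption, not confirmed by this snippet.
@spaces.GPU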
def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
    if pdf_file_binary is None:
        return "No PDF file uploaded."

    global model, tokenizer
    model.eval()

    # The cache directory is keyed by the MD5 of the uploaded PDF, so re-uploading
    # the same file reuses the existing knowledge base.
    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
    this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
    os.makedirs(this_cache_dir, exist_ok=True)

    with open(os.path.join(this_cache_dir, "src.pdf"), 'wb') as file:
        file.write(pdf_file_binary)

    dpi = 200
    doc = fitz.open("pdf", pdf_file_binary)

    reps_list = []
    images = []
    image_md5s = []

    # Render each page to an image and embed it with the vision-language model.
    for page in progress.tqdm(doc):
        pix = page.get_pixmap(dpi=dpi)
        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        image_md5 = get_image_md5(image)
        image_md5s.append(image_md5)
        with torch.no_grad():
            reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
        reps_list.append(reps.squeeze(0).cpu().numpy())
        images.append(image)

    # Persist the rendered pages, their embeddings, and the page order.
    for idx in range(len(images)):
        image = images[idx]
        image_md5 = image_md5s[idx]
        cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
        image.save(cache_image_path)

    np.save(os.path.join(this_cache_dir, "reps.npy"), reps_list)

    with open(os.path.join(this_cache_dir, "md5s.txt"), 'w') as f:
        for item in image_md5s:
            f.write(item + '\n')

    return "PDF processed successfully!"
def retrieve_gradio(pdf_file_binary, query: str, topk: int):
    global model, tokenizer
    model.eval()

    if pdf_file_binary is None:
        # The Gallery output cannot display a message string, so return nothing.
        return None

    knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
    target_cache_dir = os.path.join(cache_dir, knowledge_base_name)
    if not os.path.exists(target_cache_dir):
        return None

    # Load the cached page order and page embeddings.
    md5s = []
    with open(os.path.join(target_cache_dir, "md5s.txt"), 'r') as f:
        for line in f:
            md5s.append(line.rstrip('\n'))

    doc_reps = np.load(os.path.join(target_cache_dir, "reps.npy"))

    query_with_instruction = "Represent this query for retrieving relevant document: " + query
    with torch.no_grad():
        query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()

    query_md5 = hashlib.md5(query.encode()).hexdigest()

    # Rank pages by dot-product similarity between the query and page embeddings.
    doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
    similarities = torch.matmul(query_rep, doc_reps_cat.T)
    topk_values, topk_doc_ids = torch.topk(similarities, k=int(topk))

    images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids.cpu().numpy()]
    return images_topk

with gr.Blocks() as app:
    gr.Markdown("# MiniCPMV-RAG-PDFQA")

    with gr.Row():
        file_input = gr.File(type="binary", label="Upload PDF")
        process_button = gr.Button("Process PDF")

    # The click handler needs an actual output component rather than the string "text".
    process_output = gr.Textbox(label="Processing Status")
    process_button.click(add_pdf_gradio, inputs=[file_input], outputs=process_output)

    with gr.Row():
        query_input = gr.Text(label="Your Question")
        topk_input = gr.Number(value=5, minimum=1, maximum=10, step=1, label="Number of pages to retrieve")

    retrieve_button = gr.Button("Retrieve Pages")
    images_output = gr.Gallery(label="Retrieved Pages")
    retrieve_button.click(retrieve_gradio, inputs=[file_input, query_input, topk_input], outputs=images_output)

app.launch(share=True)