File size: 4,024 Bytes
94b55f0
 
602d806
 
fc4a494
 
 
 
 
 
3649694
602d806
 
 
 
 
 
fc4a494
602d806
 
 
 
 
 
 
 
 
 
fc4a494
602d806
fc4a494
 
 
 
 
 
 
 
602d806
 
fc4a494
602d806
 
fc4a494
 
 
602d806
 
 
654c2e1
602d806
 
 
 
 
 
 
 
 
 
fc4a494
602d806
d5db6a5
94b55f0
3649694
fc4a494
3649694
fc4a494
602d806
fc4a494
 
602d806
fc4a494
602d806
 
fc4a494
 
 
602d806
fc4a494
602d806
fc4a494
 
 
 
 
 
 
 
 
 
 
602d806
fc4a494
 
 
 
602d806
fc4a494
 
 
602d806
fc4a494
 
602d806
 
fc4a494
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

import gradio as gr
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor


def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that score highest against ``query``.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced by ``index`` (one CPU tensor per page),
            in the same order as ``images``.
        images: Page images produced by ``index``; ``images[i]`` is the page
            whose embedding is ``ds[i]``.
        k: Number of results to return (value of the "Number of results"
            slider).

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``, ordered
        from highest to lowest score.

    Raises:
        gr.Error: If no pages have been indexed yet.
    """
    if not ds:
        # Scoring against an empty corpus fails deep inside the evaluator;
        # give the user an actionable message instead.
        raise gr.Error("Please upload and convert PDF files before searching.")

    with torch.no_grad():
        # process_queries is given a blank placeholder image (mock_image)
        # alongside the text query.
        batch_query = process_queries(processor, [query], mock_image)
        # Distinct names so the comprehension does not shadow the `k` argument.
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # Defensive int() in case the slider delivers a float; slice indices
    # must be integers.
    top_k = int(k)
    # Row 0 holds the scores of our single query. Sorting descending and then
    # taking the first `top_k` selects the same pages as `[-k:][::-1]` for
    # k >= 1, but returns nothing (instead of everything) when k == 0.
    ranked_indices = scores.argsort(axis=1)[0][::-1]
    top_k_indices = ranked_indices[:top_k] if top_k > 0 else []

    return [(images[idx], f"Page {idx}") for idx in top_k_indices]


def index(files, ds):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    Args:
        files: The uploaded PDFs from the ``gr.File`` component; each entry is
            passed directly to ``convert_from_path``. ``None`` or empty when
            nothing has been uploaded.
        ds: Embedding state from a previous run. It is intentionally NOT
            extended: the index is rebuilt from ``files`` on every call so that
            ``embeddings[i]`` always corresponds to ``images[i]``. Appending
            to it would keep stale embeddings from earlier conversions whose
            indices point past the end of, or at the wrong entries of, the
            freshly returned ``images``.

    Returns:
        ``(status_message, embeddings, images)`` where ``embeddings`` holds
        one CPU tensor per page and ``images`` holds the page images in the
        same order.

    Raises:
        gr.Error: If no files were provided, or the PDFs contain 150 or more
            pages in total.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF file.")

    images = []
    for f in files:
        images.extend(convert_from_path(f))

    if len(images) >= 150:
        raise gr.Error("The number of images in the dataset should be less than 150.")

    # Batch the pages 4 at a time; process_images builds the model inputs
    # (a dict of tensors) for each batch of page images.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )

    embeddings = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # One embedding per page; moved to CPU to keep GPU memory free.
        embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images

# --- Model and processor setup (executed once, at import time) ---

# Downloads are cached under <cwd>/data/model_cache/.
cache_dir = os.path.join(os.getcwd(), "data/", "model_cache/")

# ColPali weights, loaded below as an adapter on top of the PaliGemma base.
model_name = "vidore/colpali"

# Optional Hugging Face access token; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")

model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    token=token,
    cache_dir=cache_dir,
)
model = model.eval()  # inference only

# NOTE(review): load_adapter is not given token/cache_dir, so the adapter is
# fetched via the default Hugging Face settings — confirm this is intended.
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, cache_dir=cache_dir, token=token)

# Inputs are moved to whichever device the model ended up on.
device = model.device

# Blank white 448x448 RGB image, passed to process_queries as a placeholder.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# Gradio UI: upload PDFs -> index them -> search the indexed pages.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
    gr.Markdown("""Demo to test ColPali on PDF documents. The inference code is based on the [ViDoRe benchmark](https://github.com/illuin-tech/vidore-benchmark).

    ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).

    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio file types are extensions with a leading dot (or
            # category names such as "image"); a bare "pdf" is not recognised.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("🔄 Convert and upload")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # Per-session state: page embeddings and the matching page images,
            # both written by `index` and read by `search`.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=3)

    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])

if __name__ == "__main__":
    # Enable request queuing with at most 10 pending jobs
    # (queue() configures `demo` in place).
    demo.queue(max_size=10)
    # Listen on all interfaces so the demo is reachable from other hosts.
    demo.launch(server_name="0.0.0.0", server_port=7861, debug=True)