manu committed on
Commit 3649694 • 1 Parent(s): d40ecad

Update app.py

Files changed (1)
  1. app.py +11 -17
app.py CHANGED
@@ -1,9 +1,8 @@
 import os
 
 import gradio as gr
-from pdf2image import convert_from_path
-
 import torch
+from pdf2image import convert_from_path
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -60,7 +59,7 @@ def search(query: str, ds, images):
     retriever_evaluator = CustomEvaluator(is_multi_vector=True)
     scores = retriever_evaluator.evaluate(qs, ds)
     best_page = int(scores.argmax(axis=1).item())
-    return f"The most relevant page is {best_page}", images[best_page]
+    return f"The most relevant page is {best_page}", images[best_page]
 
 
 def index(file, ds):
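Context for the retained lines above: `evaluate(qs, ds)` scores every query against every indexed page, and `argmax(axis=1)` picks the highest-scoring page per query. A minimal sketch of that selection step with a made-up score matrix (NumPy here for illustration; the actual return type of `CustomEvaluator.evaluate` may differ):

    import numpy as np

    # Hypothetical scores: one row per query, one column per indexed page.
    scores = np.array([[0.12, 0.87, 0.31]])
    best_page = int(scores.argmax(axis=1).item())  # index of the top-scoring page
    assert best_page == 1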
@@ -84,18 +83,20 @@ def index(file, ds):
     return f"Uploaded and converted {len(images)} pages", ds, images
 
 
-COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
 # Load model
 model_name = "coldoc/colpali-3b-mix-448"
 token = os.environ.get("HF_TOKEN")
-model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token=token).eval()
+model = ColPali.from_pretrained(
+    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token=token
+).eval()
 model.load_adapter(model_name)
 processor = AutoProcessor.from_pretrained(model_name, token=token)
 device = model.device
 mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
 
 with gr.Blocks() as demo:
-    gr.Markdown("# PDF to 🤗 Dataset")
+    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚🔍")
     gr.Markdown("## 1️⃣ Upload PDFs")
     file = gr.File(file_types=["pdf"], file_count="multiple")
 
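A note on the reflowed `from_pretrained` call: it loads the PaliGemma base weights and then attaches the ColPali adapter via `model.load_adapter(model_name)`. The hard-coded `torch_dtype=torch.bfloat16` and `device_map="cuda"` assume a bfloat16-capable GPU; a fallback like the sketch below is one option (my assumption, not part of this commit):

    import torch

    # Degrade gracefully on CPU-only machines or GPUs without bfloat16.
    use_cuda = torch.cuda.is_available()
    dtype = torch.bfloat16 if use_cuda and torch.cuda.is_bf16_supported() else torch.float32
    device_map = "cuda" if use_cuda else None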
@@ -103,14 +104,10 @@ with gr.Blocks() as demo:
     convert_button = gr.Button("🔄 Convert and upload")
     message = gr.Textbox("Files not yet uploaded")
     embeds = gr.State(value=[])
-    imgs = gr.State()
+    imgs = gr.State(value=[])
 
     # Define the actions
-    convert_button.click(
-        index,
-        inputs=[file, embeds],
-        outputs=[message, embeds, imgs]
-    )
+    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
 
     gr.Markdown("## 3️⃣ Search")
     query = gr.Textbox(placeholder="Enter your query here")
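The `imgs = gr.State(value=[])` fix matters because `gr.State()` with no argument starts each session with `None`; initializing it to an empty list matches the `embeds = gr.State(value=[])` pattern one line above, so `index` can extend it and `search` can index into it. A small illustration (Gradio semantics as I understand them; verify against your installed version):

    import gradio as gr

    imgs_old = gr.State()          # initial value is None -> images[best_page] raises TypeError
    imgs_new = gr.State(value=[])  # initial value is a list -> safe to extend and index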
@@ -118,11 +115,8 @@ with gr.Blocks() as demo:
     message2 = gr.Textbox("Query not yet set")
     output_img = gr.Image()
 
-    search_button.click(
-        search, inputs=[query, embeds, imgs],
-        outputs=[message2, output_img]
-    )
+    search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(debug=True)
+    demo.queue(max_size=10).launch(debug=True)
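Usage note on the final line: `demo.queue(max_size=10)` enables Gradio's request queue and caps it at ten pending events, with further submissions turned away with a queue-full message until it drains; `launch(debug=True)` blocks the main thread and surfaces errors in the console (behavior as of Gradio versions contemporary with this commit).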
 