ahmed-masry committed
Commit d0976af
1 Parent(s): a6b6457

Update README.md

Files changed (1):
  1. README.md +58 -28
README.md CHANGED
````diff
@@ -24,65 +24,95 @@ pip install . -e
 Then, you can run the following inference code:
 
 ```python
+import pprint
+from typing import List, cast
+
 import torch
-import typer
+from datasets import Dataset, load_dataset
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoProcessor
-from PIL import Image
 
-from colpali_engine.models.paligemma_colbert_architecture import ColPali
-from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
-from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
-from colpali_engine.utils.image_from_page_utils import load_from_dataset
+from colpali_engine.models import ColFlor
+from colpali_engine.models import ColFlorProcessor
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
+
 
-def main() -> None:
-    """Example script to run inference with ColPali"""
+def main():
+    """
+    Example script to run inference with ColFlor.
+    """
 
+    device = get_torch_device("auto")
+    print(f"Device used: {device}")
+
+    # Model name
+    model_name = "ahmed-masry/ColFlor"
+
     # Load model
-    model_name = "vidore/colpali"
-    model = ColPali.from_pretrained("vidore/colpaligemma-3b-mix-448-base", torch_dtype=torch.bfloat16, device_map="cuda").eval()
-    model.load_adapter(model_name)
-    processor = AutoProcessor.from_pretrained(model_name)
+    model = ColFlor.from_pretrained(
+        model_name,
+        #torch_dtype=torch.bfloat16,
+        device_map=device,
+    ).eval()
+
+    # Load processor
+    processor = cast(ColFlorProcessor, ColFlorProcessor.from_pretrained(model_name))
 
-    # select images -> load_from_pdf(<pdf_path>), load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
-    images = load_from_dataset("vidore/docvqa_test_subsampled")
-    queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]
+    if not isinstance(processor, BaseVisualRetrieverProcessor):
+        raise ValueError("Processor should be a BaseVisualRetrieverProcessor")
 
-    # run inference - docs
+    # NOTE: Only the first 16 images are used for demonstration purposes
+    dataset = cast(Dataset, load_dataset("vidore/docvqa_test_subsampled", split="test[:16]"))
+    images = dataset["image"]
+
+    # Select a few queries for demonstration purposes
+    query_indices = [12, 15]
+    queries = [dataset[idx]["query"] for idx in query_indices]
+    print("Selected queries:")
+    pprint.pprint(dict(zip(query_indices, queries)))
+
+    # Run inference - docs
     dataloader = DataLoader(
-        images,
+        dataset=ListDataset[str](images),
         batch_size=4,
         shuffle=False,
-        collate_fn=lambda x: process_images(processor, x),
+        collate_fn=lambda x: processor.process_images(x),
     )
-    ds = []
+    ds: List[torch.Tensor] = []
     for batch_doc in tqdm(dataloader):
         with torch.no_grad():
             batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
             embeddings_doc = model(**batch_doc)
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
 
-    # run inference - queries
+    # Run inference - queries
     dataloader = DataLoader(
-        queries,
+        dataset=ListDataset[str](queries),
         batch_size=4,
         shuffle=False,
-        collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
+        collate_fn=lambda x: processor.process_queries(x),
     )
 
-    qs = []
+    qs: List[torch.Tensor] = []
     for batch_query in dataloader:
         with torch.no_grad():
             batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
             embeddings_query = model(**batch_query)
         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
-    # run evaluation
-    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
-    scores = retriever_evaluator.evaluate(qs, ds)
-    print(scores.argmax(axis=1))
+    # Run scoring
+    scores = processor.score(qs, ds).cpu().numpy()
+    idx_top_1 = scores.argmax(axis=1)
+    print("Indices of the top-1 retrieved documents for each query:", idx_top_1)
+
+    # Sanity check
+    if idx_top_1.tolist() == query_indices:
+        print("The top-1 retrieved documents are correct.")
+    else:
+        print("The top-1 retrieved documents are incorrect.")
+
+    return
 
 
 if __name__ == "__main__":
````
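
For context on what the updated snippet's `processor.score(qs, ds)` call computes: ColFlor, like the other multi-vector retrievers in `colpali_engine`, scores a query against a page with late interaction (ColBERT-style MaxSim) over the per-token embeddings collected in the two loops above. The sketch below is a minimal illustration of that scoring rule, not the library's actual implementation; it assumes each entry of `qs` and `ds` is an unpadded `(num_tokens, dim)` tensor with embeddings already normalized by the model.

```python
import torch


def maxsim_scores(qs: list[torch.Tensor], ds: list[torch.Tensor]) -> torch.Tensor:
    """Late-interaction (MaxSim) scoring sketch.

    qs: per-query token embeddings, each of shape (n_query_tokens, dim)
    ds: per-page token embeddings, each of shape (n_page_tokens, dim)
    Returns a (len(qs), len(ds)) score matrix.
    """
    scores = torch.empty(len(qs), len(ds))
    for i, q in enumerate(qs):
        for j, d in enumerate(ds):
            # Similarity of every query token against every page token.
            sim = q @ d.T  # (n_query_tokens, n_page_tokens)
            # MaxSim: keep the best-matching page token per query token,
            # then sum over query tokens to get the page score.
            scores[i, j] = sim.max(dim=1).values.sum()
    return scores
```

`scores.argmax(axis=1)` in the snippet then picks, for each query, the index of its highest-scoring page, which is why the sanity check compares `idx_top_1` against `query_indices`.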