ahmed-masry
commited on
Commit
•
d0976af
1
Parent(s):
a6b6457
Update README.md
Browse files
README.md
CHANGED
@@ -24,65 +24,95 @@ pip install . -e
|
|
24 |
Then, you can run the following inference code:
|
25 |
|
26 |
```python
|
|
|
|
|
|
|
27 |
import torch
|
28 |
-
import
|
29 |
from torch.utils.data import DataLoader
|
30 |
from tqdm import tqdm
|
31 |
-
from transformers import AutoProcessor
|
32 |
-
from PIL import Image
|
33 |
|
34 |
-
from colpali_engine.models
|
35 |
-
from colpali_engine.
|
36 |
-
from colpali_engine.utils.
|
37 |
-
from colpali_engine.utils.
|
|
|
38 |
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
|
41 |
-
"
|
|
|
|
|
|
|
42 |
|
43 |
# Load model
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]
|
52 |
|
53 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
dataloader = DataLoader(
|
55 |
-
images,
|
56 |
batch_size=4,
|
57 |
shuffle=False,
|
58 |
-
collate_fn=lambda x: process_images(
|
59 |
)
|
60 |
-
ds = []
|
61 |
for batch_doc in tqdm(dataloader):
|
62 |
with torch.no_grad():
|
63 |
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
64 |
embeddings_doc = model(**batch_doc)
|
65 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
66 |
|
67 |
-
#
|
68 |
dataloader = DataLoader(
|
69 |
-
queries,
|
70 |
batch_size=4,
|
71 |
shuffle=False,
|
72 |
-
collate_fn=lambda x: process_queries(
|
73 |
)
|
74 |
|
75 |
-
qs = []
|
76 |
for batch_query in dataloader:
|
77 |
with torch.no_grad():
|
78 |
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
79 |
embeddings_query = model(**batch_query)
|
80 |
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
81 |
|
82 |
-
#
|
83 |
-
|
84 |
-
|
85 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
|
88 |
if __name__ == "__main__":
|
|
|
24 |
Then, you can run the following inference code:
|
25 |
|
26 |
```python
|
27 |
+
import pprint
|
28 |
+
from typing import List, cast
|
29 |
+
|
30 |
import torch
|
31 |
+
from datasets import Dataset, load_dataset
|
32 |
from torch.utils.data import DataLoader
|
33 |
from tqdm import tqdm
|
|
|
|
|
34 |
|
35 |
+
from colpali_engine.models import ColFlor
|
36 |
+
from colpali_engine.models import ColFlorProcessor
|
37 |
+
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
|
38 |
+
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
39 |
+
|
40 |
|
41 |
+
def main():
|
42 |
+
"""
|
43 |
+
Example script to run inference with ColFlor.
|
44 |
+
"""
|
45 |
|
46 |
+
device = get_torch_device("auto")
|
47 |
+
print(f"Device used: {device}")
|
48 |
+
|
49 |
+
# Model name
|
50 |
+
model_name = "ahmed-masry/ColFlor"
|
51 |
|
52 |
# Load model
|
53 |
+
model = ColFlor.from_pretrained(
|
54 |
+
model_name,
|
55 |
+
#torch_dtype=torch.bfloat16,
|
56 |
+
device_map=device,
|
57 |
+
).eval()
|
58 |
+
|
59 |
+
# Load processor
|
60 |
+
processor = cast(ColFlorProcessor, ColFlorProcessor.from_pretrained(model_name))
|
61 |
|
62 |
+
if not isinstance(processor, BaseVisualRetrieverProcessor):
|
63 |
+
raise ValueError("Processor should be a BaseVisualRetrieverProcessor")
|
|
|
64 |
|
65 |
+
# NOTE: Only the first 16 images are used for demonstration purposes
|
66 |
+
dataset = cast(Dataset, load_dataset("vidore/docvqa_test_subsampled", split="test[:16]"))
|
67 |
+
images = dataset["image"]
|
68 |
+
|
69 |
+
# Select a few queries for demonstration purposes
|
70 |
+
query_indices = [12, 15]
|
71 |
+
queries = [dataset[idx]["query"] for idx in query_indices]
|
72 |
+
print("Selected queries:")
|
73 |
+
pprint.pprint(dict(zip(query_indices, queries)))
|
74 |
+
|
75 |
+
# Run inference - docs
|
76 |
dataloader = DataLoader(
|
77 |
+
dataset=ListDataset[str](images),
|
78 |
batch_size=4,
|
79 |
shuffle=False,
|
80 |
+
collate_fn=lambda x: processor.process_images(x),
|
81 |
)
|
82 |
+
ds: List[torch.Tensor] = []
|
83 |
for batch_doc in tqdm(dataloader):
|
84 |
with torch.no_grad():
|
85 |
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
86 |
embeddings_doc = model(**batch_doc)
|
87 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
88 |
|
89 |
+
# Run inference - queries
|
90 |
dataloader = DataLoader(
|
91 |
+
dataset=ListDataset[str](queries),
|
92 |
batch_size=4,
|
93 |
shuffle=False,
|
94 |
+
collate_fn=lambda x: processor.process_queries(x),
|
95 |
)
|
96 |
|
97 |
+
qs: List[torch.Tensor] = []
|
98 |
for batch_query in dataloader:
|
99 |
with torch.no_grad():
|
100 |
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
101 |
embeddings_query = model(**batch_query)
|
102 |
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
103 |
|
104 |
+
# Run scoring
|
105 |
+
scores = processor.score(qs, ds).cpu().numpy()
|
106 |
+
idx_top_1 = scores.argmax(axis=1)
|
107 |
+
print("Indices of the top-1 retrieved documents for each query:", idx_top_1)
|
108 |
+
|
109 |
+
# Sanity check
|
110 |
+
if idx_top_1.tolist() == query_indices:
|
111 |
+
print("The top-1 retrieved documents are correct.")
|
112 |
+
else:
|
113 |
+
print("The top-1 retrieved documents are incorrect.")
|
114 |
+
|
115 |
+
return
|
116 |
|
117 |
|
118 |
if __name__ == "__main__":
|