gabrielaltay committed
Commit f40eab1
1 parent: 57ce204

initial commit

Files changed (2)
  1. app.py +207 -0
  2. requirements.txt +107 -0
app.py ADDED
@@ -0,0 +1,207 @@
+import tempfile
+
+from colpali_engine.models.paligemma_colbert_architecture import ColPali
+from colpali_engine.utils.colpali_processing_utils import process_images
+from colpali_engine.utils.colpali_processing_utils import process_queries
+import google.generativeai as genai
+import numpy as np
+import pdf2image
+from PIL import Image
+import requests
+import streamlit as st
+import torch
+from torch.utils.data import DataLoader
+from transformers import AutoProcessor
+
+
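+# shorthand for streamlit's session_state, which persists values across script reruns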
+SS = st.session_state
+
+
+def initialize_session_state():
+    keys = [
+        "colpali_model",
+        "page_images",
+        "retrieved_page_images",
+        "response",
+    ]
+    for key in keys:
+        if key not in SS:
+            SS[key] = None
+
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    return device
+
+
+def get_dtype(device: torch.device):
+    if device == torch.device("cuda"):
+        dtype = torch.bfloat16
+    elif device == torch.device("mps"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float32
+    return dtype
+
+
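+# load the PaliGemma base checkpoint and attach the ColPali retrieval adapter on top of it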
+def load_colpali_model():
+    paligemma_model_name = "google/paligemma-3b-mix-448"
+    colpali_model_name = "vidore/colpali"
+    device = get_device()
+    dtype = get_dtype(device)
+
+    model = ColPali.from_pretrained(paligemma_model_name, torch_dtype=dtype).eval()
+    model.load_adapter(colpali_model_name)
+    model.to(device)
+    processor = AutoProcessor.from_pretrained(colpali_model_name)
+    return model, processor
+
+
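+# embed rendered PDF pages with ColPali: each page yields a set of patch-level embedding vectors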
+def embed_page_images(model, processor, page_images, batch_size=2):
+    dataloader = DataLoader(
+        page_images,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda x: process_images(processor, x),
+    )
+    page_embeddings = []
+    for batch in dataloader:
+        with torch.no_grad():
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+            embeddings = model(**batch)
+        page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
+    return np.array(page_embeddings)
+
+
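+# embed the query text with the same model; the processor expects an image, so a blank placeholder is passed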
+def embed_query_texts(model, processor, query_texts, batch_size=1):
+    # 448 matches the input resolution of the paligemma-3b-mix-448 checkpoint loaded above
+    dummy_image = Image.new("RGB", (448, 448), (255, 255, 255))
+    dataloader = DataLoader(
+        query_texts,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=lambda x: process_queries(processor, x, dummy_image),
+    )
+    query_embeddings = []
+    for batch in dataloader:
+        with torch.no_grad():
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+            embeddings = model(**batch)
+        query_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
+    return np.array(query_embeddings)[0]
+
+
+
+def get_pdf_page_images_from_bytes(
+    pdf_bytes: bytes,
+    use_tmp_dir=False,
+):
+    if use_tmp_dir:
+        with tempfile.TemporaryDirectory() as tmp_path:
+            page_images = pdf2image.convert_from_bytes(pdf_bytes, output_folder=tmp_path)
+    else:
+        page_images = pdf2image.convert_from_bytes(pdf_bytes)
+    return page_images
+
+
+def get_pdf_bytes_from_url(url: str) -> bytes | None:
+    response = requests.get(url)
+    if response.status_code == 200:
+        return response.content
+    else:
+        print(f"failed to fetch {url}")
+        print(response)
+        return None
+
+
+def display_pages(page_images, key):
+    n_cols = st.slider("ncol", min_value=1, max_value=8, value=4, step=1, key=key)
+    cols = st.columns(n_cols)
+    for ii_page, page_image in enumerate(page_images):
+        ii_col = ii_page % n_cols
+        with cols[ii_col]:
+            st.image(page_image)
+
+
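+# streamlit reruns this script on every interaction, so the model and embeddings are kept in session state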
+initialize_session_state()
+
+
+if SS["colpali_model"] is None:
+    SS["colpali_model"], SS["processor"] = load_colpali_model()
+
+
+with st.sidebar:
+    url = st.text_input("arxiv url", "https://arxiv.org/pdf/2112.01488.pdf")
+
+    if st.button("load paper"):
+        pdf_bytes = get_pdf_bytes_from_url(url)
+        SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)
+
+
+    if st.button("embed pages"):
+        SS["page_embeddings"] = embed_page_images(
+            SS["colpali_model"],
+            SS["processor"],
+            SS["page_images"],
+        )
+
+
+with st.container(border=True):
+    query = st.text_area("query")
+    top_k = st.slider("num pages to retrieve", min_value=1, max_value=8, value=3, step=1)
+    if st.button("answer query"):
+        SS["query_embeddings"] = embed_query_texts(
+            SS["colpali_model"],
+            SS["processor"],
+            [query],
+        )
+
+        page_query_scores = []
+        for ipage in range(len(SS["page_embeddings"])):
+            # MaxSim: for each query token, take the max similarity over all page patches, then sum
+            patch_query_scores = np.dot(
+                SS["page_embeddings"][ipage],
+                SS["query_embeddings"].T,
+            )
+            max_sim_score = patch_query_scores.max(axis=0).sum()
+            page_query_scores.append(max_sim_score)
+
+        page_query_scores = np.array(page_query_scores)
+        i_ranked_pages = np.argsort(-page_query_scores)
+
+        page_images = []
+        for ii in range(top_k):
+            page_images.append(SS["page_images"][i_ranked_pages[ii]])
+        SS["retrieved_page_images"] = page_images
+
+
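+        # build a multimodal prompt: the query plus answering instructions, followed by the retrieved page images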
+        prompt = [
+            query +
+            " Think through your answer step by step. "
+            "Support your answer with descriptions of the images. "
+            "Do not infer information that is not in the images.",
+        ] + page_images
+
+        genai.configure(api_key=st.secrets["google_genai_api_key"])
+        # gen_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
+        gen_model = genai.GenerativeModel(model_name="gemini-1.5-pro")
+        response = gen_model.generate_content(prompt)
+        text = response.candidates[0].content.parts[0].text
+        SS["response"] = text
+
+
+if SS["response"] is not None:
+    st.write(SS["response"])
+    st.header("Retrieved Pages")
+    display_pages(SS["retrieved_page_images"], "retrieved_pages")
+
+
+
+if SS["page_images"] is not None:
+    st.header("All PDF Pages")
+    display_pages(SS["page_images"], "all_pages")
requirements.txt ADDED
@@ -0,0 +1,107 @@
+accelerate==0.32.1
+aiohttp==3.9.5
+aiosignal==1.3.1
+altair==5.3.0
+annotated-types==0.7.0
+async-timeout==4.0.3
+attrs==23.2.0
+black==24.4.2
+blinker==1.8.2
+cachetools==5.4.0
+certifi==2024.7.4
+charset-normalizer==3.3.2
+click==8.1.7
+colpali_engine @ git+https://github.com/illuin-tech/colpali@8b01824546c62e46383ce26b439d9bfc6468f763
+datasets==2.20.0
+dill==0.3.8
+eval_type_backport==0.2.0
+filelock==3.15.4
+frozenlist==1.4.1
+fsspec==2024.5.0
+gitdb==4.0.11
+GitPython==3.1.43
+google-ai-generativelanguage==0.6.6
+google-api-core==2.19.1
+google-api-python-client==2.137.0
+google-auth==2.32.0
+google-auth-httplib2==0.2.0
+google-generativeai==0.7.2
+googleapis-common-protos==1.63.2
+GPUtil==1.4.0
+grpcio==1.65.1
+grpcio-status==1.62.2
+httplib2==0.22.0
+huggingface-hub==0.24.0
+idna==3.7
+importlib_metadata==7.2.1
+Jinja2==3.1.4
+joblib==1.4.2
+jsonschema==4.23.0
+jsonschema-specifications==2023.12.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mpmath==1.3.0
+mteb==1.12.85
+multidict==6.0.5
+multiprocess==0.70.16
+mypy-extensions==1.0.0
+networkx==3.3
+numpy==1.26.4
+packaging==23.2
+pandas==2.2.2
+pathspec==0.12.1
+pdf2image==1.17.0
+peft==0.11.1
+pillow==10.4.0
+platformdirs==4.2.2
+polars==1.2.1
+proto-plus==1.24.0
+protobuf==4.25.3
+psutil==6.0.0
+pyarrow==17.0.0
+pyarrow-hotfix==0.6
+pyasn1==0.6.0
+pyasn1_modules==0.4.0
+pydantic==2.8.2
+pydantic_core==2.20.1
+pydeck==0.9.1
+Pygments==2.18.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+pytrec_eval-terrier==0.5.6
+pytz==2024.1
+PyYAML==6.0.1
+referencing==0.35.1
+regex==2024.5.15
+requests==2.32.3
+rich==13.7.1
+rpds-py==0.19.0
+rsa==4.9
+safetensors==0.4.3
+scikit-learn==1.5.1
+scipy==1.14.0
+sentence-transformers==3.0.1
+six==1.16.0
+smmap==5.0.1
+streamlit==1.31.1
+sympy==1.13.1
+tenacity==8.5.0
+threadpoolctl==3.5.0
+tokenizers==0.19.1
+toml==0.10.2
+tomli==2.0.1
+toolz==0.12.1
+torch==2.3.1
+tornado==6.4.1
+tqdm==4.66.4
+transformers==4.42.4
+typing_extensions==4.12.2
+tzdata==2024.1
+tzlocal==5.2
+uritemplate==4.1.1
+urllib3==2.2.2
+validators==0.33.0
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.19.2