import os
from collections.abc import Iterator

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from model import encode_images, encode_queries
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space


def iter_batch(
    X: list, batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[list]:
    """Iterate over a list of elements by batches."""
    batches = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]
    if tqdm_bar:
        for batch in tqdm.tqdm(
            iterable=batches,
            position=0,
            total=len(batches),
            desc=desc,
        ):
            yield batch
    else:
        yield from batches
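
# Example: batching five items two at a time (tqdm_bar=False disables the
# progress bar).
#
#     >>> list(iter_batch([1, 2, 3, 4, 5], batch_size=2, tqdm_bar=False))
#     [[1, 2], [3, 4], [5]]
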
class Voyager:
    """Voyager index. Voyager is a fast and efficient index for approximate nearest neighbor search.

    Parameters
    ----------
    index_folder
        The folder in which the index files are stored.
    index_name
        The name of the collection.
    override
        Whether to override the collection if it already exists.
    embedding_size
        The number of dimensions of the embeddings.
    M
        The maximum number of connections per node in the HNSW graph.
    ef_construction
        The number of candidates to evaluate during the construction of the index.
    ef_search
        The number of candidates to evaluate during the search.
    """
    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search

        if not os.path.exists(index_folder):
            os.makedirs(index_folder, exist_ok=True)

        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )

        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_construction=ef_construction,
            override=override,
        )
    def _load_page_ids_to_data(self) -> SqliteDict:
        """Load the SQLite database that maps page IDs to page data (path, image, page number)."""
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)
    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_construction: int,
        override: bool,
    ) -> Index:
        """Create a new Voyager collection, or load an existing one.

        Parameters
        ----------
        index_path
            The path to the index.
        embedding_size
            The size of the embeddings.
        M
            The maximum number of connections per node in the HNSW graph.
        ef_construction
            The number of candidates to evaluate during the construction of the index.
        override
            Whether to override the collection if it already exists.
        """
        if os.path.exists(index_path) and not override:
            return Index.load(index_path)

        if os.path.exists(index_path):
            os.remove(index_path)

        # Create the Voyager index
        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_construction,
        )
        index.save(index_path)

        if override and os.path.exists(self.page_ids_to_data_path):
            os.remove(self.page_ids_to_data_path)

        # Create the SQLite database that maps page IDs to page data
        page_ids_to_data = self._load_page_ids_to_data()
        page_ids_to_data.close()

        return index
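
    # Note (usage sketch, hypothetical names): with override=False an existing
    # collection is reloaded as-is, so previously indexed pages stay
    # searchable; override=True deletes both the .voyager file and the SQLite
    # mapping and starts from an empty collection.
    #
    #     index = Voyager(index_name="demo")                 # load or create
    #     fresh = Voyager(index_name="demo", override=True)  # start from scratch
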
    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index. Note that `batch_size` is the number of pages to encode at once, not the number of documents."""
        if isinstance(paths, str):
            paths = [paths]

        page_ids_to_data = self._load_page_ids_to_data()

        # Render each PDF page (or load each image file) as a PIL image
        images = []
        num_pages = []
        for path in paths:
            if path.lower().endswith(".pdf"):
                pdf = pdfium.PdfDocument(path)
                n_pages = len(pdf)
                num_pages.append(n_pages)
                for page_number in range(n_pages):
                    page = pdf.get_page(page_number)
                    pil_image = page.render(
                        scale=1,
                        rotation=0,
                    )
                    pil_image = pil_image.to_pil()
                    images.append(pil_image)
                pdf.close()
            else:
                pil_image = Image.open(path)
                images.append(pil_image)
                num_pages.append(1)

        # Encode the page images by batches
        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))

        embeddings_ids = self.index.add_items(embeddings)

        # Map each embedding ID to its page data. Keys are stored as strings
        # to match the string lookup performed at query time.
        current_index = 0
        for i, path in enumerate(paths):
            for page_number in range(num_pages[i]):
                page_ids_to_data[str(embeddings_ids[current_index])] = {
                    "path": path,
                    "image": images[current_index],
                    "page_number": page_number,
                }
                current_index += 1

        page_ids_to_data.commit()
        page_ids_to_data.close()
        self.index.save(self.index_path)
        return self
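
    # Example (hypothetical files): each PDF is expanded into one entry per
    # page, while a plain image counts as a single page.
    #
    #     index.add_documents(["report.pdf", "figure.png"], batch_size=4)
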
    def __call__(
        self,
        queries: list[str],
        k: int = 10,
    ) -> dict:
        """Query the index for the nearest neighbors of the queries.

        Parameters
        ----------
        queries
            The queries to search for; they are encoded with `encode_queries` before searching.
        k
            The number of nearest neighbors to return.
        """
        queries_embeddings = encode_queries(queries)

        page_ids_to_data = self._load_page_ids_to_data()
        if len(page_ids_to_data) == 0:
            raise ValueError("Index is empty, add documents before querying.")

        k = min(k, len(page_ids_to_data))
        n_queries = len(queries_embeddings)

        indices, distances = self.index.query(
            queries_embeddings, k, query_ef=self.ef_search
        )

        documents = [
            [page_ids_to_data[str(idx)] for idx in query_indices]
            for query_indices in indices
        ]
        page_ids_to_data.close()

        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }
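
# Minimal end-to-end sketch (hypothetical file and query names; assumes the
# `model` module imported above provides `encode_images` / `encode_queries`
# for page images and text queries respectively):
if __name__ == "__main__":
    index = Voyager(index_folder="indexes", index_name="demo", override=True)
    index.add_documents(paths=["document.pdf", "diagram.png"], batch_size=4)

    results = index(queries=["What does the diagram show?"], k=3)
    for document, distance in zip(results["documents"][0], results["distances"][0][0]):
        print(document["path"], document["page_number"], float(distance))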