File size: 6,615 Bytes
349b5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
from collections.abc import Iterator

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from model import encode_images, encode_queries
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space


def iter_batch(
    X: list, batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[list]:
    """Yield consecutive slices of ``X`` holding at most ``batch_size`` elements.

    Parameters
    ----------
    X
        The elements to split into batches.
    batch_size
        The maximum number of elements per batch. The last batch may be smaller.
    tqdm_bar
        Whether to wrap the iteration in a tqdm progress bar.
    desc
        The description displayed next to the progress bar.

    Yields
    ------
    list
        The next slice of ``X``. Nothing is yielded when ``X`` is empty.

    """
    batches = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]

    if not tqdm_bar:
        yield from batches
        return

    # Report the real number of batches: `1 + len(X) // batch_size` is one too
    # many whenever len(X) is an exact multiple of batch_size, which left the
    # progress bar stuck below 100%.
    yield from tqdm.tqdm(
        iterable=batches,
        position=0,
        total=len(batches),
        desc=desc,
    )


class Voyager:
    """Page retrieval index built on a Voyager index and a SQLite mapping.

    Page embeddings are stored in a ``voyager.Index`` saved at
    ``<index_folder>/<index_name>.voyager``. A ``SqliteDict`` saved at
    ``<index_folder>/<index_name>_page_ids_to_data.sqlite`` maps each embedding
    id to the page it was computed from (source path, page number, image).

    Parameters
    ----------
    index_folder
        The folder holding the index and the SQLite database. It is created
        when it does not exist.
    index_name
        The name of the collection, used as the stem of both file names.
    override
        Whether to override the collection if it already exists.
    embedding_size
        The number of dimensions of the embeddings.
    M
        Forwarded to ``voyager.Index`` as ``M``. The voyager documentation
        describes it as the number of connections between nodes of the
        index's internal graph.
    ef_construction
        The number of candidates to evaluate during the construction of the index.
    ef_search
        The number of candidates to evaluate during the search; forwarded to
        ``Index.query`` as ``query_ef``.
    """

    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search

        os.makedirs(index_folder, exist_ok=True)

        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )

        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_constructions=ef_construction,
            override=override,
        )

    def _load_page_ids_to_data(self) -> SqliteDict:
        """Open the SQLite database that maps embedding ids to page data.

        The caller is responsible for closing the returned object (directly or
        by using it as a context manager) and for committing any write.
        """
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)

    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_constructions: int,
        override: bool,
    ) -> Index:
        """Load the Voyager index stored at ``index_path`` or create a new one.

        Parameters
        ----------
        index_path
            The path to the index.
        embedding_size
            The size of the embeddings.
        M
            Forwarded to ``voyager.Index`` as ``M``.
        ef_constructions
            The number of candidates to evaluate during the construction of the index.
        override
            Whether to override the collection if it already exists.

        Returns
        -------
        Index
            The index loaded from disk when it exists and ``override`` is
            false, otherwise a freshly created index already saved to disk.

        """
        if os.path.exists(index_path):
            if not override:
                return Index.load(index_path)
            os.remove(index_path)

        # Create the Voyager index
        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_constructions,
        )
        index.save(index_path)

        # An overridden collection must not keep the mapping of the old one.
        if override and os.path.exists(self.page_ids_to_data_path):
            os.remove(self.page_ids_to_data_path)

        # Opening then closing the mapping creates the SQLite database file.
        self._load_page_ids_to_data().close()
        return index

    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index.

        Paths ending in ``.pdf`` (case-insensitive) are rendered page by page;
        any other path is opened as a single image.

        Parameters
        ----------
        paths
            A path or a list of paths to the documents to index.
        batch_size
            The number of pages (not documents) to encode at once.

        Returns
        -------
        Voyager
            The index itself, so that calls can be chained.

        """
        if isinstance(paths, str):
            paths = [paths]

        images = []
        # (path, page_number) of each image, in the same order as `images`.
        page_refs = []

        for path in paths:
            if path.lower().endswith(".pdf"):
                pdf = pdfium.PdfDocument(path)
                try:
                    for page_number in range(len(pdf)):
                        page = pdf.get_page(page_number)
                        bitmap = page.render(
                            scale=1,
                            rotation=0,
                        )
                        images.append(bitmap.to_pil())
                        page_refs.append((path, page_number))
                finally:
                    # Release the document even when a page fails to render.
                    pdf.close()
            else:
                images.append(Image.open(path))
                page_refs.append((path, 0))

        # Nothing to encode: leave the index and the mapping untouched.
        if not images:
            return self

        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))

        embeddings_ids = self.index.add_items(embeddings)

        # The context manager closes the database even if a write fails.
        with self._load_page_ids_to_data() as page_ids_to_data:
            for embedding_id, image, (path, page_number) in zip(
                embeddings_ids, images, page_refs
            ):
                # Keys are written as strings because lookups use str(id).
                page_ids_to_data[str(embedding_id)] = {
                    "path": path,
                    "image": image,
                    "page_number": page_number,
                }
            # Commit explicitly: the mapping is opened without autocommit.
            page_ids_to_data.commit()

        self.index.save(self.index_path)

        return self

    def __call__(
        self,
        queries: np.ndarray | torch.Tensor,
        k: int = 10,
    ) -> dict:
        """Query the index for the nearest neighbors of the queries.

        Parameters
        ----------
        queries
            The queries; they are encoded with ``encode_queries`` before the
            search. NOTE(review): the annotation says arrays/tensors, confirm
            what ``encode_queries`` actually expects.
        k
            The number of nearest neighbors to return. It is capped to the
            number of indexed pages.

        Returns
        -------
        dict
            ``"documents"``: for each query, the list of stored page records
            (``path``, ``image``, ``page_number``) of its neighbors.
            ``"distances"``: the distances returned by the index, reshaped to
            ``(n_queries, -1, k)``.

        Raises
        ------
        ValueError
            If no page has been indexed yet, or if the search returns nothing.

        """
        queries_embeddings = encode_queries(queries)
        n_queries = len(queries_embeddings)

        # The context manager closes the database on every exit path.
        with self._load_page_ids_to_data() as page_ids_to_data:
            n_indexed = len(page_ids_to_data)
            # Fail before searching: k would be capped to 0 on an empty index.
            if n_indexed == 0:
                raise ValueError("Index is empty, add documents before querying.")
            k = min(k, n_indexed)

            indices, distances = self.index.query(
                queries_embeddings, k, query_ef=self.ef_search
            )

            if len(indices) == 0:
                raise ValueError("Index is empty, add documents before querying.")

            documents = [
                [page_ids_to_data[str(indice)] for indice in query_indices]
                for query_indices in indices
            ]

        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }