vjeronymo2 committed
Commit 828992f
Parent(s): 9970bea

Adding model and checkpoint

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- README.md +159 -0
- colbert/__init__.py +0 -0
- colbert/__pycache__/__init__.cpython-37.pyc +0 -0
- colbert/__pycache__/index.cpython-37.pyc +0 -0
- colbert/__pycache__/index_faiss.cpython-37.pyc +0 -0
- colbert/__pycache__/parameters.cpython-37.pyc +0 -0
- colbert/__pycache__/retrieve.cpython-37.pyc +0 -0
- colbert/__pycache__/train.cpython-37.pyc +0 -0
- colbert/evaluation/__init__.py +0 -0
- colbert/evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
- colbert/evaluation/__pycache__/load_model.cpython-37.pyc +0 -0
- colbert/evaluation/__pycache__/loaders.cpython-37.pyc +0 -0
- colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc +0 -0
- colbert/evaluation/load_model.py +28 -0
- colbert/evaluation/loaders.py +196 -0
- colbert/evaluation/metrics.py +114 -0
- colbert/evaluation/ranking.py +88 -0
- colbert/evaluation/ranking_logger.py +57 -0
- colbert/evaluation/slow.py +21 -0
- colbert/index.py +59 -0
- colbert/index_faiss.py +43 -0
- colbert/indexing/__init__.py +0 -0
- colbert/indexing/__pycache__/__init__.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/encoder.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/faiss.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/faiss_index.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/index_manager.cpython-37.pyc +0 -0
- colbert/indexing/__pycache__/loaders.cpython-37.pyc +0 -0
- colbert/indexing/encoder.py +187 -0
- colbert/indexing/faiss.py +116 -0
- colbert/indexing/faiss_index.py +58 -0
- colbert/indexing/faiss_index_gpu.py +138 -0
- colbert/indexing/index_manager.py +22 -0
- colbert/indexing/loaders.py +34 -0
- colbert/modeling/__init__.py +0 -0
- colbert/modeling/__pycache__/__init__.cpython-37.pyc +0 -0
- colbert/modeling/__pycache__/colbert.cpython-37.pyc +0 -0
- colbert/modeling/__pycache__/inference.cpython-37.pyc +0 -0
- colbert/modeling/colbert.py +73 -0
- colbert/modeling/inference.py +87 -0
- colbert/modeling/tokenization/__init__.py +3 -0
- colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc +0 -0
- colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc +0 -0
- colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc +0 -0
- colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc +0 -0
- colbert/modeling/tokenization/doc_tokenization.py +63 -0
- colbert/modeling/tokenization/query_tokenization.py +64 -0
- colbert/modeling/tokenization/utils.py +51 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.dnn filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,159 @@
# ColBERT

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.

<p align="center">
  <img align="center" src="docs/images/ColBERT-Framework-MaxSim-W370px.png" />
</p>
<p align="center">
  <b>Figure 1:</b> ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
</p>

As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
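To make the `MaxSim` operator concrete, here is a minimal sketch of late-interaction scoring for one query/passage pair (illustrative only; the function name and shapes are our own, not the repository's API):

```
import torch

def maxsim_score(Q, D):
    # Q: (query_tokens, dim) query embedding matrix; D: (doc_tokens, dim) passage matrix.
    similarities = Q @ D.T                       # all token-to-token similarities
    return similarities.max(dim=1).values.sum()  # best match per query token, summed

# Toy example with random L2-normalized embeddings (i.e., cosine similarity).
Q = torch.nn.functional.normalize(torch.randn(32, 128), dim=-1)
D = torch.nn.functional.normalize(torch.randn(180, 128), dim=-1)
print(maxsim_score(Q, D))
```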
These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:

* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21; to appear).

----

## Installation

ColBERT (currently: [v0.2.0](#releases)) requires Python 3.7+ and PyTorch 1.6+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using:

```
conda env create -f conda_env.yml
conda activate colbert-v0.2
```

If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!


## Overview

Using ColBERT on a dataset typically involves the following steps.

**Step 0: Preprocess your collection.** At its simplest, ColBERT works with tab-separated (TSV) files: a file (e.g., `collection.tsv`) will contain all passages and another (e.g., `queries.tsv`) will contain a set of queries for searching the collection.

**Step 1: Train a ColBERT model.** You can [train your own ColBERT model](#training) and [validate performance](#validation) on a suitable development set.

**Step 2: Index your collection.** Once you're happy with your ColBERT model, you need to [index your collection](#indexing) to permit fast retrieval. This step encodes all passages into matrices, stores them on disk, and builds data structures for efficient search.

**Step 3: Search the collection with your queries.** Given your model and index, you can [issue queries over the collection](#retrieval) to retrieve the top-k passages for each query.

Below, we illustrate these steps via an example run on the MS MARCO Passage Ranking task.


## Data

This repository works directly with a simple **tab-separated file** format to store queries, passages, and top-k ranked lists.

* Queries: each line is `qid \t query text`.
* Collection: each line is `pid \t passage text`.
* Top-k Ranking: each line is `qid \t pid \t rank`.

This works directly with the data format of the [MS MARCO Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) dataset. You will need the training triples (`triples.train.small.tar.gz`), the official top-1000 ranked lists for the dev set queries (`top1000.dev`), and the dev set relevant passages (`qrels.dev.small.tsv`). For indexing the full collection, you will also need the list of passages (`collection.tar.gz`).
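As a minimal sketch of consuming the first two formats (illustrative only; the repository's own loaders live in `colbert/evaluation/loaders.py`, included in this commit):

```
from collections import OrderedDict

def read_queries(path):
    queries = OrderedDict()                 # each line: qid \t query text
    with open(path) as f:
        for line in f:
            qid, query = line.strip().split('\t')[:2]
            queries[int(qid)] = query
    return queries

def read_collection(path):
    collection = {}                         # each line: pid \t passage text
    with open(path) as f:
        for line in f:
            pid, passage = line.strip().split('\t')[:2]
            collection[int(pid)] = passage
    return collection
```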


## Training

Training requires a list of _<query, positive passage, negative passage>_ tab-separated triples.

You can supply **full-text** triples, where each line is `query text \t positive passage text \t negative passage text`. Alternatively, you can supply the query and passage **IDs** as a JSONL file with one `[qid, pid+, pid-]` triple per line, in which case you should specify `--collection path/to/collection.tsv` and `--queries path/to/queries.train.tsv`. An example of the ID-based format is shown below.
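For concreteness, two lines of such an ID-based triples file might look as follows (the IDs here are made up for illustration):

```
[3, 1182104, 7541337]
[8, 320122, 9921004]
```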

```
CUDA_VISIBLE_DEVICES="0,1,2,3" \
python -m torch.distributed.launch --nproc_per_node=4 -m \
colbert.train --amp --doc_maxlen 180 --mask-punctuation --bsize 32 --accum 1 \
--triples /path/to/MSMARCO/triples.train.small.tsv \
--root /root/to/experiments/ --experiment MSMARCO-psg --similarity l2 --run msmarco.psg.l2
```

You can use one or more GPUs by modifying `CUDA_VISIBLE_DEVICES` and `--nproc_per_node`.


## Validation

Before indexing into ColBERT, you can compare a few checkpoints by re-ranking a top-k set of documents per query. This will use ColBERT _on-the-fly_: it will compute document representations _during_ query evaluation.

This script requires the top-k list per query, provided as a tab-separated file whose every line contains a tuple `queryID \t passageID \t rank`, where rank is {1, 2, 3, ...} for each query. The script also accepts the format of MS MARCO's `top1000.dev` and `top1000.eval`, and you can optionally supply relevance judgements (qrels) for evaluation. This is a tab-separated file whose every line has a quadruple _<query ID, 0, passage ID, 1>_, like `qrels.dev.small.tsv`.

Example command:

```
python -m colbert.test --amp --doc_maxlen 180 --mask-punctuation \
--collection /path/to/MSMARCO/collection.tsv \
--queries /path/to/MSMARCO/queries.dev.small.tsv \
--topk /path/to/MSMARCO/top1000.dev \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--root /root/to/experiments/ --experiment MSMARCO-psg [--qrels path/to/qrels.dev.small.tsv]
```
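For intuition, a minimal sketch of computing MRR@10 from a ranking file (`qid \t pid \t rank`) and a qrels file in the format above (illustrative only; the full metrics live in `colbert/evaluation/metrics.py`, included in this commit):

```
from collections import defaultdict

def mrr_at_10(ranking_path, qrels_path):
    positives = defaultdict(set)
    with open(qrels_path) as f:
        for line in f:
            qid, _, pid, _ = line.strip().split('\t')
            positives[int(qid)].add(int(pid))

    best_rank = {}                          # best (smallest) relevant rank per query
    with open(ranking_path) as f:
        for line in f:
            qid, pid, rank = map(int, line.strip().split('\t')[:3])
            if pid in positives[qid]:
                best_rank[qid] = min(best_rank.get(qid, rank), rank)

    return sum(1.0 / r for r in best_rank.values() if r <= 10) / max(1, len(positives))
```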


## Indexing

For fast retrieval, indexing precomputes the ColBERT representations of passages.

Example command:

```
CUDA_VISIBLE_DEVICES="0,1,2,3" OMP_NUM_THREADS=6 \
python -m torch.distributed.launch --nproc_per_node=4 -m \
colbert.index --amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--collection /path/to/MSMARCO/collection.tsv \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--root /root/to/experiments/ --experiment MSMARCO-psg
```

The index created here allows you to re-rank the top-k passages retrieved by another method (e.g., BM25).

We typically recommend that you use ColBERT for **end-to-end** retrieval, where it directly finds its top-k passages from the full collection. For this, you need FAISS indexing.


#### FAISS Indexing for end-to-end retrieval

For end-to-end retrieval, you should index the document representations into [FAISS](https://github.com/facebookresearch/faiss).

```
python -m colbert.index_faiss \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--partitions 32768 --sample 0.3 \
--root /root/to/experiments/ --experiment MSMARCO-psg
```
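Under the hood, this step trains an IVFPQ index over the token-level embeddings (see `colbert/indexing/faiss_index.py` in this commit). A self-contained sketch of the same construction, with random vectors standing in for real embeddings:

```
import faiss
import numpy as np

dim, partitions = 128, 1024                 # embedding dim; cf. --partitions
quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFPQ(quantizer, dim, partitions, 16, 8)  # 16 sub-quantizers, 8 bits each

sample = np.random.randn(50_000, dim).astype('float32')
index.train(sample)                         # k-means over a sample of embeddings
index.add(sample)                           # then add all embeddings
index.nprobe = 32                           # partitions probed at search time (cf. --nprobe)
```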


## Retrieval

In the simplest case, you want to retrieve from the full collection:

```
python -m colbert.retrieve \
--amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
--queries /path/to/MSMARCO/queries.dev.small.tsv \
--nprobe 32 --partitions 32768 --faiss_depth 1024 \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--root /root/to/experiments/ --experiment MSMARCO-psg
```

You may also want to re-rank a top-k set that you've retrieved before with ColBERT or with another model. For this, use `colbert.rerank` similarly and additionally pass `--topk`.

If you have a large set of queries (or want to reduce memory usage), use **batch-mode** retrieval and/or re-ranking. This can be done by passing `--batch --only_retrieval` to `colbert.retrieve` and passing `--batch --log-scores` to `colbert.rerank` alongside `--topk` with the `unordered.tsv` output of this retrieval run.

Some use cases (e.g., building a user-facing search engine) require more control over retrieval. For those, you typically don't want to use the command line for retrieval. Instead, you want to import our retrieval API from Python and directly work with that (e.g., to build a simple REST API). Instructions for this are coming soon, but you will just need to adapt/modify the retrieval loop in [`colbert/ranking/retrieval.py#L33`](colbert/ranking/retrieval.py#L33).


## Releases

* v0.2.0: Sep 2020
* v0.1.0: June 2020
colbert/__init__.py
ADDED
File without changes

colbert/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (120 Bytes)

colbert/__pycache__/index.cpython-37.pyc
ADDED
Binary file (1.61 kB)

colbert/__pycache__/index_faiss.cpython-37.pyc
ADDED
Binary file (1.43 kB)

colbert/__pycache__/parameters.cpython-37.pyc
ADDED
Binary file (354 Bytes)

colbert/__pycache__/retrieve.cpython-37.pyc
ADDED
Binary file (1.73 kB)

colbert/__pycache__/train.cpython-37.pyc
ADDED
Binary file (1.13 kB)

colbert/evaluation/__init__.py
ADDED
File without changes

colbert/evaluation/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (131 Bytes)

colbert/evaluation/__pycache__/load_model.cpython-37.pyc
ADDED
Binary file (932 Bytes)

colbert/evaluation/__pycache__/loaders.cpython-37.pyc
ADDED
Binary file (6.09 kB)

colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc
ADDED
Binary file (2.12 kB)
colbert/evaluation/load_model.py
ADDED
@@ -0,0 +1,28 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
colbert/evaluation/loaders.py
ADDED
@@ -0,0 +1,196 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels


def load_topK(topK_path):
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)
            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection


def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
colbert/evaluation/metrics.py
ADDED
@@ -0,0 +1,114 @@
import ujson

from collections import defaultdict
from colbert.utils.runs import Run


class Metrics:
    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        self.results = {}
        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0

    def add(self, query_idx, query_key, ranking, gold_positives):
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if len(positives) == 0:
            return

        for depth in self.mrr_sums:
            first_positive = positives[0]
            self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0

        for depth in self.success_sums:
            first_positive = positives[0]
            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        for depth in sorted(self.mrr_sums):
            print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.success_sums):
            print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.recall_sums):
            print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0))

    def log(self, query_idx):
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        output = defaultdict(dict)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            output['mrr'][depth] = score

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            output['success'][depth] = score

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            output['recall'][depth] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


def evaluate_recall(qrels, queries, topK_pids):
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())
    recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
                   for qid in qrels]
    recall_at_k = sum(recall_at_k) / len(qrels)
    recall_at_k = round(recall_at_k, 3)
    print("Recall @ maximum depth =", recall_at_k)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
colbert/evaluation/ranking.py
ADDED
@@ -0,0 +1,88 @@
import os
import random
import time
import torch
import torch.nn as nn

from itertools import accumulate
from math import ceil

from colbert.utils.runs import Run
from colbert.utils.utils import print_message

from colbert.evaluation.metrics import Metrics
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.modeling.inference import ModelInference

from colbert.evaluation.slow import slow_rerank


def evaluate(args):
    args.inference = ModelInference(args.colbert, amp=args.amp)
    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        topK_docs = args.topK_docs

    def qid2passages(qid):
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        else:
            return topK_docs[qid][:depth]

    metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                      success_depths={5, 10, 20, 50, 100, 1000},
                      total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            keys = sorted(list(queries.keys()))
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0:
                    continue

                ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid))

                rlogger.log(qid, ranking, [0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    for i, (score, pid, passage) in enumerate(ranking):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", i+1, "with score", score)
                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

        print("\n\n")
        # print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))
        print("\n\n")

    print('\n\n')
    if qrels:
        assert query_idx + 1 == len(keys) == len(set(keys))
        metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))
    print('\n\n')
colbert/evaluation/ranking_logger.py
ADDED
@@ -0,0 +1,57 @@
import os

from contextlib import contextmanager
from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run


class RankingLogger():
    def __init__(self, directory, qrels=None, log_scores=False):
        self.directory = directory
        self.qrels = qrels
        self.filename, self.also_save_annotations = None, None
        self.log_scores = log_scores

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        assert self.filename is None
        assert self.also_save_annotations is None

        filename = os.path.join(self.directory, filename)
        self.filename, self.also_save_annotations = filename, also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(filename, 'w') as f:
            self.f = f
            with (open(filename + '.annotated', 'w') if also_save_annotations else NullContextManager()) as g:
                self.g = g
                try:
                    yield self
                finally:
                    pass

    def log(self, qid, ranking, is_ranked=True, print_positions=[]):
        print_positions = set(print_positions)

        f_buffer = []
        g_buffer = []

        for rank, (score, pid, passage) in enumerate(ranking):
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = rank+1 if is_ranked else -1

            possibly_score = [score] if self.log_scores else []

            f_buffer.append('\t'.join([str(x) for x in [qid, pid, rank] + possibly_score]) + "\n")
            if self.g:
                g_buffer.append('\t'.join([str(x) for x in [qid, pid, rank, is_relevant]]) + "\n")

            if rank in print_positions:
                prefix = "** " if is_relevant else ""
                prefix += str(rank)
                print("#> ( QID {} ) ".format(qid) + prefix + ") ", pid, ":", score, ' ', passage)

        self.f.write(''.join(f_buffer))
        if self.g:
            self.g.write(''.join(g_buffer))
colbert/evaluation/slow.py
ADDED
@@ -0,0 +1,21 @@
import os


def slow_rerank(args, query, pids, passages):
    colbert = args.colbert
    inference = args.inference

    Q = inference.queryFromText([query])

    D_ = inference.docFromText(passages, bsize=args.bsize)
    scores = colbert.score(Q, D_).cpu()

    scores = scores.sort(descending=True)
    ranked = scores.indices.tolist()

    ranked_scores = scores.values.tolist()
    ranked_pids = [pids[position] for position in ranked]
    ranked_passages = [passages[position] for position in ranked]

    assert len(ranked_pids) == len(set(ranked_pids))

    return list(zip(ranked_scores, ranked_pids, ranked_passages))
colbert/index.py
ADDED
@@ -0,0 +1,59 @@
import os
import ujson
import random

from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
import colbert.utils.distributed as distributed

from colbert.utils.utils import print_message, create_directory
from colbert.indexing.encoder import CollectionEncoder


def main():
    random.seed(12345)

    parser = Arguments(description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    parser.add_argument('--chunksize', dest='chunksize', default=6.0, required=False, type=float)   # in GiBs

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert not os.path.exists(args.index_path), args.index_path

        distributed.barrier(args.rank)

        if args.rank < 1:
            create_directory(args.index_root)
            create_directory(args.index_path)

        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args, process_idx=process_idx, num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path, "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)


if __name__ == "__main__":
    main()

# TODO: Add resume functionality
colbert/index_faiss.py
ADDED
@@ -0,0 +1,43 @@
import os
import random
import math

from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
from colbert.indexing.faiss import index_faiss
from colbert.indexing.loaders import load_doclens


def main():
    random.seed(12345)

    parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    parser.add_index_use_input()

    parser.add_argument('--sample', dest='sample', default=None, type=float)
    parser.add_argument('--slices', dest='slices', default=1, type=int)

    args = parser.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(num_embeddings))
            print('\n\n')

        index_faiss(args)


if __name__ == "__main__":
    main()
colbert/indexing/__init__.py
ADDED
File without changes

colbert/indexing/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (129 Bytes)

colbert/indexing/__pycache__/encoder.cpython-37.pyc
ADDED
Binary file (5.92 kB)

colbert/indexing/__pycache__/faiss.cpython-37.pyc
ADDED
Binary file (3.44 kB)

colbert/indexing/__pycache__/faiss_index.cpython-37.pyc
ADDED
Binary file (1.92 kB)

colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc
ADDED
Binary file (4.11 kB)

colbert/indexing/__pycache__/index_manager.cpython-37.pyc
ADDED
Binary file (880 Bytes)

colbert/indexing/__pycache__/loaders.cpython-37.pyc
ADDED
Binary file (1.76 kB)
colbert/indexing/encoder.py
ADDED
@@ -0,0 +1,187 @@
import os
import time
import torch
import ujson
import numpy as np

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from colbert.utils.utils import print_message

from colbert.indexing.index_manager import IndexManager


class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024*1024*1024)

        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        # Determine subset sizes for output
        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)
        self.iterator = self._initialize_iterator()

    def _initialize_iterator(self):
        return open(self.collection)

    def _saver_thread(self):
        for args in iter(self.saver_queue.get, None):
            self._save_batch(*args)

    def _load_model(self):
        self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
            if owner != self.process_idx:
                continue

            t1 = time.time()
            batch = self._preprocess_batch(offset, lines)
            embs, doclens = self._encode_batch(batch_idx, batch)

            t2 = time.time()
            self.saver_queue.put((batch_idx, embs, offset, doclens))

            t3 = time.time()
            local_docs_processed += len(lines)
            overall_throughput = compute_throughput(local_docs_processed, t0, t3)
            this_encoding_throughput = compute_throughput(len(lines), t1, t2)
            this_saving_throughput = compute_throughput(len(lines), t2, t3)

            self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                       f'Passages/min: {overall_throughput} (overall), ',
                       f'{this_encoding_throughput} (this encoding), ',
                       f'{this_saving_throughput} (this saving)')
        self.saver_queue.put(None)

        self.print("#> Joining saver thread.")
        thread.join()

    def _batch_passages(self, fi):
        """
        Must use the same seed across processes!
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = np.random.choice(self.possible_subset_sizes)

            L = [line for _, line in zip(range(batch_size), fi)]

            if len(L) == 0:
                break  # EOF

            yield (offset, L, owner)
            offset += len(L)

            if len(L) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

        return

    def _preprocess_batch(self, offset, lines):
        endpos = offset + len(lines)

        batch = []

        for line_idx, line in zip(range(offset, endpos), lines):
            line_parts = line.strip().split('\t')

            pid, passage, *other = line_parts

            assert len(passage) >= 1

            if len(other) >= 1:
                title, *_ = other
                passage = title + ' | ' + passage

            batch.append(passage)

            # assert pid == 'id' or int(pid) == line_idx

        return batch

    def _encode_batch(self, batch_idx, batch):
        with torch.no_grad():
            embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
            assert type(embs) is list
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens

    def _save_batch(self, batch_idx, embs, offset, doclens):
        start_time = time.time()

        output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))

        # Save the embeddings.
        self.indexmgr.save(embs, output_path)
        self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))], output_sample_path)

        # Save the doclens.
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        if self.process_idx == 0:
            self.print(*args)


def compute_throughput(size, t0, t1):
    throughput = size / (t1 - t0) * 60

    if throughput > 1000 * 1000:
        throughput = throughput / (1000*1000)
        throughput = round(throughput, 1)
        return '{}M'.format(throughput)

    throughput = throughput / (1000)
    throughput = round(throughput, 1)
    return '{}k'.format(throughput)
colbert/indexing/faiss.py
ADDED
@@ -0,0 +1,116 @@
import os
import math
import faiss
import torch
import numpy as np

import threading
import queue

from colbert.utils.utils import print_message, grouper
from colbert.indexing.loaders import get_parts
from colbert.indexing.index_manager import load_index_part
from colbert.indexing.faiss_index import FaissIndex


def get_faiss_index_name(args, offset=None, endpos=None):
    partitions_info = '' if args.partitions is None else f'.{args.partitions}'
    range_info = '' if offset is None else f'.{offset}-{endpos}'

    return f'ivfpq{partitions_info}{range_info}.faiss'


def load_sample(samples_paths, sample_fraction=None):
    sample = []

    for filename in samples_paths:
        print_message(f"#> Loading {filename} ...")
        part = load_index_part(filename)
        if sample_fraction:
            part = part[torch.randint(0, high=part.size(0), size=(int(part.size(0) * sample_fraction),))]
        sample.append(part)

    sample = torch.cat(sample).float().numpy()

    print("#> Sample has shape", sample.shape)

    return sample


def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    training_sample = load_sample(slice_samples_paths, sample_fraction=sample_fraction)

    dim = training_sample.shape[-1]
    index = FaissIndex(dim, partitions)

    print_message("#> Training with the vectors...")

    index.train(training_sample)

    print_message("Done training!\n")

    return index


SPAN = 3


def index_faiss(args):
    print_message("#> Starting..")

    parts, parts_paths, samples_paths = get_parts(args.index_path)

    if args.sample is not None:
        assert args.sample, args.sample
        print_message(f"#> Training with {round(args.sample * 100.0, 1)}% of *all* embeddings (provided --sample).")
        samples_paths = parts_paths

    num_parts_per_slice = math.ceil(len(parts) / args.slices)

    for slice_idx, part_offset in enumerate(range(0, len(parts), num_parts_per_slice)):
        part_endpos = min(part_offset + num_parts_per_slice, len(parts))

        slice_parts_paths = parts_paths[part_offset:part_endpos]
        slice_samples_paths = samples_paths[part_offset:part_endpos]

        if args.slices == 1:
            faiss_index_name = get_faiss_index_name(args)
        else:
            faiss_index_name = get_faiss_index_name(args, offset=part_offset, endpos=part_endpos)

        output_path = os.path.join(args.index_path, faiss_index_name)
        print_message(f"#> Processing slice #{slice_idx+1} of {args.slices} (range {part_offset}..{part_endpos}).")
        print_message(f"#> Will write to {output_path}.")

        assert not os.path.exists(output_path), output_path

        index = prepare_faiss_index(slice_samples_paths, args.partitions, args.sample)

        loaded_parts = queue.Queue(maxsize=1)

        def _loader_thread(thread_parts_paths):
            for filenames in grouper(thread_parts_paths, SPAN, fillvalue=None):
                sub_collection = [load_index_part(filename) for filename in filenames if filename is not None]
                sub_collection = torch.cat(sub_collection)
                sub_collection = sub_collection.float().numpy()
                loaded_parts.put(sub_collection)

        thread = threading.Thread(target=_loader_thread, args=(slice_parts_paths,))
        thread.start()

        print_message("#> Indexing the vectors...")

        for filenames in grouper(slice_parts_paths, SPAN, fillvalue=None):
            print_message("#> Loading", filenames, "(from queue)...")
            sub_collection = loaded_parts.get()

            print_message("#> Processing a sub_collection with shape", sub_collection.shape)
            index.add(sub_collection)

        print_message("Done indexing!")

        index.save(output_path)

        print_message(f"\n\nDone! All complete (for slice #{slice_idx+1} of {args.slices})!")

        thread.join()
colbert/indexing/faiss_index.py
ADDED
@@ -0,0 +1,58 @@
import sys
import time
import math
import faiss
import torch

import numpy as np

from colbert.indexing.faiss_index_gpu import FaissIndexGPU
from colbert.utils.utils import print_message


class FaissIndex():
    def __init__(self, dim, partitions):
        self.dim = dim
        self.partitions = partitions

        self.gpu = FaissIndexGPU()
        self.quantizer, self.index = self._create_index()
        self.offset = 0

    def _create_index(self):
        quantizer = faiss.IndexFlatL2(self.dim)  # faiss.IndexHNSWFlat(dim, 32)
        index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, 16, 8)

        return quantizer, index

    def train(self, train_data):
        print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...")

        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        s = time.time()
        self.index.train(train_data)
        print(time.time() - s)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()

    def add(self, data):
        print_message(f"Add data with shape {data.shape} (offset = {self.offset})..")

        if self.gpu.ngpu > 0 and self.offset == 0:
            self.gpu.adding_initialize(self.index)

        if self.gpu.ngpu > 0:
            self.gpu.add(self.index, data, self.offset)
        else:
            self.index.add(data)

        self.offset += data.shape[0]

    def save(self, output_path):
        print_message(f"Writing index to {output_path} ...")

        self.index.nprobe = 10  # just a default
        faiss.write_index(self.index, output_path)
colbert/indexing/faiss_index_gpu.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
"""
|
2 |
+
Heavily based on: https://github.com/facebookresearch/faiss/blob/master/benchs/bench_gpu_1bn.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import faiss
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from colbert.utils.utils import print_message
|
14 |
+
|
15 |
+
|
16 |
+
class FaissIndexGPU():
|
17 |
+
def __init__(self):
|
18 |
+
self.ngpu = faiss.get_num_gpus()
|
19 |
+
|
20 |
+
if self.ngpu == 0:
|
21 |
+
return
|
22 |
+
|
23 |
+
self.tempmem = 1 << 33
|
24 |
+
self.max_add_per_gpu = 1 << 25
|
25 |
+
self.max_add = self.max_add_per_gpu * self.ngpu
|
26 |
+
self.add_batch_size = 65536
|
27 |
+
|
28 |
+
self.gpu_resources = self._prepare_gpu_resources()
|
29 |
+
|
30 |
+
def _prepare_gpu_resources(self):
|
31 |
+
print_message(f"Preparing resources for {self.ngpu} GPUs.")
|
32 |
+
|
33 |
+
gpu_resources = []
|
34 |
+
|
35 |
+
for _ in range(self.ngpu):
|
36 |
+
res = faiss.StandardGpuResources()
|
37 |
+
if self.tempmem >= 0:
|
38 |
+
res.setTempMemory(self.tempmem)
|
39 |
+
gpu_resources.append(res)
|
40 |
+
|
41 |
+
return gpu_resources
|
42 |
+
|
43 |
+
def _make_vres_vdev(self):
|
44 |
+
"""
|
45 |
+
return vectors of device ids and resources useful for gpu_multiple
|
46 |
+
"""
|
47 |
+
|
48 |
+
assert self.ngpu > 0
|
49 |
+
|
50 |
+
vres = faiss.GpuResourcesVector()
|
51 |
+
vdev = faiss.IntVector()
|
52 |
+
|
53 |
+
for i in range(self.ngpu):
|
54 |
+
vdev.push_back(i)
|
55 |
+
vres.push_back(self.gpu_resources[i])
|
56 |
+
|
57 |
+
return vres, vdev
|
58 |
+
|
59 |
+
def training_initialize(self, index, quantizer):
|
60 |
+
"""
|
61 |
+
The index and quantizer should be owned by caller.
|
62 |
+
"""
|
63 |
+
|
64 |
+
assert self.ngpu > 0
|
65 |
+
|
66 |
+
s = time.time()
|
67 |
+
self.index_ivf = faiss.extract_index_ivf(index)
|
68 |
+
self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer)
|
69 |
+
self.index_ivf.clustering_index = self.clustering_index
|
70 |
+
print(time.time() - s)
|
71 |
+
|
72 |
+
def training_finalize(self):
|
73 |
+
assert self.ngpu > 0
|
74 |
+
|
75 |
+
s = time.time()
|
76 |
+
self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index)
|
77 |
+
print(time.time() - s)
|
78 |
+
|
79 |
+
def adding_initialize(self, index):
|
80 |
+
"""
|
81 |
+
The index should be owned by caller.
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert self.ngpu > 0
|
85 |
+
|
86 |
+
self.co = faiss.GpuMultipleClonerOptions()
|
87 |
+
self.co.useFloat16 = True
|
88 |
+
self.co.useFloat16CoarseQuantizer = False
|
89 |
+
self.co.usePrecomputed = False
|
90 |
+
self.co.indicesOptions = faiss.INDICES_CPU
|
91 |
+
self.co.verbose = True
|
92 |
+
self.co.reserveVecs = self.max_add
|
93 |
+
self.co.shard = True
|
94 |
+
assert self.co.shard_type in (0, 1, 2)
|
95 |
+
|
96 |
+
self.vres, self.vdev = self._make_vres_vdev()
|
97 |
+
self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co)
|
98 |
+
|
99 |
+
def add(self, index, data, offset):
|
100 |
+
assert self.ngpu > 0
|
101 |
+
|
102 |
+
t0 = time.time()
|
103 |
+
nb = data.shape[0]
|
104 |
+
|
105 |
+
for i0 in range(0, nb, self.add_batch_size):
|
106 |
+
i1 = min(i0 + self.add_batch_size, nb)
|
107 |
+
xs = data[i0:i1]
|
108 |
+
|
109 |
+
self.gpu_index.add_with_ids(xs, np.arange(offset+i0, offset+i1))
|
110 |
+
|
111 |
+
if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
|
112 |
+
self._flush_to_cpu(index, nb, offset)
|
113 |
+
|
114 |
+
print('\r%d/%d (%.3f s) ' % (i0, nb, time.time() - t0), end=' ')
|
115 |
+
sys.stdout.flush()
|
116 |
+
|
117 |
+
if self.gpu_index.ntotal > 0:
|
118 |
+
self._flush_to_cpu(index, nb, offset)
|
119 |
+
|
120 |
+
assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb)
|
121 |
+
print(f"add(.) time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" % (time.time() - t0))
|
122 |
+
|
123 |
+
def _flush_to_cpu(self, index, nb, offset):
|
124 |
+
print("Flush indexes to CPU")
|
125 |
+
|
126 |
+
for i in range(self.ngpu):
|
127 |
+
index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i))
|
128 |
+
index_src = faiss.index_gpu_to_cpu(index_src_gpu)
|
129 |
+
|
130 |
+
index_src.copy_subset_to(index, 0, offset, offset+nb)
|
131 |
+
index_src_gpu.reset()
|
132 |
+
index_src_gpu.reserveMemory(self.max_add)
|
133 |
+
|
134 |
+
if self.ngpu > 1:
|
135 |
+
try:
|
136 |
+
self.gpu_index.sync_with_shard_indexes()
|
137 |
+
except:
|
138 |
+
self.gpu_index.syncWithSubIndexes()
|
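A note on the flow above: vectors are sharded across all visible GPUs, accumulated up to `max_add`, and `_flush_to_cpu` then copies each GPU shard into the caller's CPU index before resetting the shards; the trailing try/except accommodates the rename of `sync_with_shard_indexes` to `syncWithSubIndexes` across FAISS versions (our reading of the fallback, not documented in this repo). A minimal sketch of the CPU-fallback gate that `FaissIndex` relies on:

from colbert.indexing.faiss_index_gpu import FaissIndexGPU

gpu = FaissIndexGPU()
if gpu.ngpu == 0:
    # No visible GPUs: FaissIndex.add() falls back to plain index.add() on CPU.
    print('No GPUs detected; indexing will run on CPU.')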
colbert/indexing/index_manager.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+import faiss
+import numpy as np
+
+from colbert.utils.utils import print_message
+
+
+class IndexManager():
+    def __init__(self, dim):
+        self.dim = dim
+
+    def save(self, tensor, path_prefix):
+        torch.save(tensor, path_prefix)
+
+
+def load_index_part(filename, verbose=True):
+    part = torch.load(filename)
+
+    if type(part) == list:  # for backward compatibility
+        part = torch.cat(part)
+
+    return part
colbert/indexing/loaders.py
ADDED
@@ -0,0 +1,34 @@
+import os
+import torch
+import ujson
+
+from math import ceil
+from itertools import accumulate
+from colbert.utils.utils import print_message
+
+
+def get_parts(directory):
+    extension = '.pt'
+
+    parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
+                    if filename.endswith(extension)])
+
+    assert list(range(len(parts))) == parts, parts
+
+    # Integer-sortedness matters.
+    parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
+    samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts]
+
+    return parts, parts_paths, samples_paths
+
+
+def load_doclens(directory, flatten=True):
+    parts, _, _ = get_parts(directory)
+
+    doclens_filenames = [os.path.join(directory, 'doclens.{}.json'.format(filename)) for filename in parts]
+    all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames]
+
+    if flatten:
+        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]
+
+    return all_doclens
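A minimal sketch of how these loaders compose with `load_index_part` from index_manager.py, assuming an index directory laid out as 0.pt, 1.pt, ... with matching doclens.N.json files (the layout implied by `get_parts`; the path is hypothetical):

from colbert.indexing.loaders import get_parts, load_doclens
from colbert.indexing.index_manager import load_index_part

index_dir = '/path/to/index'                  # hypothetical directory
parts, parts_paths, _ = get_parts(index_dir)  # expects 0.pt, 1.pt, ...
doclens = load_doclens(index_dir)             # flat list: one length per passage

first_part = load_index_part(parts_paths[0])  # a 2D tensor of token embeddings
print(len(parts), len(doclens), first_part.shape)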
colbert/modeling/__init__.py
ADDED
File without changes
colbert/modeling/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (129 Bytes)
colbert/modeling/__pycache__/colbert.cpython-37.pyc
ADDED
Binary file (3.33 kB)
colbert/modeling/__pycache__/inference.cpython-37.pyc
ADDED
Binary file (3.81 kB)
colbert/modeling/colbert.py
ADDED
@@ -0,0 +1,73 @@
+import string
+import torch
+import torch.nn as nn
+
+from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
+from colbert.parameters import DEVICE
+
+
+class ColBERT(BertPreTrainedModel):
+    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
+
+        super(ColBERT, self).__init__(config)
+
+        self.query_maxlen = query_maxlen
+        self.doc_maxlen = doc_maxlen
+        self.similarity_metric = similarity_metric
+        self.dim = dim
+
+        self.mask_punctuation = mask_punctuation
+        self.skiplist = {}
+
+        if self.mask_punctuation:
+            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+            self.skiplist = {w: True
+                             for symbol in string.punctuation
+                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}
+
+        self.bert = BertModel(config)
+        self.linear = nn.Linear(config.hidden_size, dim * 2, bias=False)
+
+        self.init_weights()
+
+    def forward(self, Q, D):
+        return self.score(self.query(*Q), self.doc(*D))
+
+    def query(self, input_ids, attention_mask):
+        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
+        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
+        Q = self.linear(Q)
+        Q = Q.split(int(Q.size(2)/2), 2)
+        Q = torch.cat(Q, 1)
+
+        return torch.nn.functional.normalize(Q, p=2, dim=2)
+
+    def doc(self, input_ids, attention_mask, keep_dims=True):
+        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
+        D = self.bert(input_ids, attention_mask=attention_mask)[0]
+        D = self.linear(D)
+        D = D.split(int(D.size(2)/2), 2)
+        D = torch.cat(D, 1)
+
+        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
+        mask = torch.cat(2*[mask], 1)
+        D = D * mask
+
+        D = torch.nn.functional.normalize(D, p=2, dim=2)
+
+        if not keep_dims:
+            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
+            D = [d[mask[idx]] for idx, d in enumerate(D)]
+
+        return D
+
+    def score(self, Q, D):
+        if self.similarity_metric == 'cosine':
+            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)
+
+        assert self.similarity_metric == 'l2'
+        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
+
+    def mask(self, input_ids):
+        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
+        return mask
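One structural detail worth flagging: `query` and `doc` here project to `dim * 2`, split the projection into two `dim`-sized halves, and concatenate them along the sequence axis, so each input token contributes two output embeddings. The scoring itself is standard late-interaction MaxSim; a minimal sketch on random tensors to make the shape arithmetic explicit (all shapes are illustrative):

import torch

B, Nq, Nd, dim = 4, 32, 180, 128   # illustrative shapes
Q = torch.nn.functional.normalize(torch.randn(B, Nq, dim), p=2, dim=2)
D = torch.nn.functional.normalize(torch.randn(B, Nd, dim), p=2, dim=2)

# Cosine MaxSim: best-matching document token per query token, summed over the query.
scores = (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)   # shape: (B,)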
colbert/modeling/inference.py
ADDED
@@ -0,0 +1,87 @@
+import torch
+
+from colbert.modeling.colbert import ColBERT
+from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer
+from colbert.utils.amp import MixedPrecisionManager
+from colbert.parameters import DEVICE
+
+
+class ModelInference():
+    def __init__(self, colbert: ColBERT, amp=False):
+        assert colbert.training is False
+
+        self.colbert = colbert
+        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
+        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)
+
+        self.amp_manager = MixedPrecisionManager(amp)
+
+    def query(self, *args, to_cpu=False, **kw_args):
+        with torch.no_grad():
+            with self.amp_manager.context():
+                Q = self.colbert.query(*args, **kw_args)
+                return Q.cpu() if to_cpu else Q
+
+    def doc(self, *args, to_cpu=False, **kw_args):
+        with torch.no_grad():
+            with self.amp_manager.context():
+                D = self.colbert.doc(*args, **kw_args)
+                return D.cpu() if to_cpu else D
+
+    def queryFromText(self, queries, bsize=None, to_cpu=False):
+        if bsize:
+            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
+            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
+            return torch.cat(batches)
+
+        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
+        return self.query(input_ids, attention_mask)
+
+    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
+        if bsize:
+            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)
+
+            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
+                       for input_ids, attention_mask in batches]
+
+            if keep_dims:
+                D = _stack_3D_tensors(batches)
+                return D[reverse_indices]
+
+            D = [d for batch in batches for d in batch]
+            return [D[idx] for idx in reverse_indices.tolist()]
+
+        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
+        return self.doc(input_ids, attention_mask, keep_dims=keep_dims)
+
+    def score(self, Q, D, mask=None, lengths=None, explain=False):
+        if lengths is not None:
+            assert mask is None, "don't supply both mask and lengths"
+
+            mask = torch.arange(D.size(1), device=DEVICE) + 1
+            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)
+
+        scores = (D @ Q)
+        scores = scores if mask is None else scores * mask.unsqueeze(-1)
+        scores = scores.max(1)
+
+        if explain:
+            assert False, "TODO"
+
+        return scores.values.sum(-1).cpu()
+
+
+def _stack_3D_tensors(groups):
+    bsize = sum([x.size(0) for x in groups])
+    maxlen = max([x.size(1) for x in groups])
+    hdim = groups[0].size(2)
+
+    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)
+
+    offset = 0
+    for x in groups:
+        endpos = offset + x.size(0)
+        output[offset:endpos, :x.size(1)] = x
+        offset = endpos
+
+    return output
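For orientation, a minimal sketch of driving ModelInference end to end; the strings are toy data, and the transposed `Q[0].T` argument reflects our reading of `score` above (which computes `D @ Q`), so treat the exact call pattern as an assumption rather than the repo's canonical ranking path:

from colbert.modeling.inference import ModelInference

# `colbert` is assumed to be an already-loaded ColBERT model in eval mode.
inference = ModelInference(colbert, amp=True)

Q = inference.queryFromText(['what is colbert?'])   # (1, num_query_tokens, dim)
D = inference.docFromText(['ColBERT is a late-interaction retriever.',
                           'FAISS accelerates nearest-neighbor search.'], bsize=32)

scores = inference.score(Q[0].T, D)                 # one MaxSim score per passage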
colbert/modeling/tokenization/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from colbert.modeling.tokenization.query_tokenization import *
+from colbert.modeling.tokenization.doc_tokenization import *
+from colbert.modeling.tokenization.utils import tensorize_triples
colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (342 Bytes)
colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc
ADDED
Binary file (2.62 kB)
colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc
ADDED
Binary file (2.75 kB)
colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (1.58 kB)
colbert/modeling/tokenization/doc_tokenization.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+
+from transformers import BertTokenizerFast
+from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length
+
+
+class DocTokenizer():
+    def __init__(self, doc_maxlen):
+        self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+        self.doc_maxlen = doc_maxlen
+
+        self.D_marker_token, self.D_marker_token_id = '[D]', self.tok.convert_tokens_to_ids('[unused1]')
+        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
+        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
+
+        assert self.D_marker_token_id == 1
+
+    def tokenize(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
+
+        if not add_special_tokens:
+            return tokens
+
+        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
+        tokens = [prefix + lst + suffix for lst in tokens]
+
+        return tokens
+
+    def encode(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+        if not add_special_tokens:
+            return ids
+
+        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
+        ids = [prefix + lst + suffix for lst in ids]
+
+        return ids
+
+    def tensorize(self, batch_text, bsize=None):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        # add placeholder for the [D] marker
+        batch_text = ['. ' + x for x in batch_text]
+
+        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
+                       return_tensors='pt', max_length=self.doc_maxlen)
+
+        ids, mask = obj['input_ids'], obj['attention_mask']
+
+        # postprocess for the [D] marker
+        ids[:, 1] = self.D_marker_token_id
+
+        if bsize:
+            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
+            batches = _split_into_batches(ids, mask, bsize)
+            return batches, reverse_indices
+
+        return ids, mask
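A minimal sketch of `DocTokenizer.tensorize` (the 180-token maxlen is illustrative), showing how the '. ' placeholder is overwritten by the [D] marker:

from colbert.modeling.tokenization import DocTokenizer

doc_tok = DocTokenizer(doc_maxlen=180)   # illustrative maxlen
ids, mask = doc_tok.tensorize(['ColBERT indexes passages with FAISS.'])

# ids[:, 0] is [CLS]; ids[:, 1] held the '. ' placeholder and is now the [D] marker (id 1).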
colbert/modeling/tokenization/query_tokenization.py
ADDED
@@ -0,0 +1,64 @@
+import torch
+
+from transformers import BertTokenizerFast
+from colbert.modeling.tokenization.utils import _split_into_batches
+
+
+class QueryTokenizer():
+    def __init__(self, query_maxlen):
+        self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+        self.query_maxlen = query_maxlen
+
+        self.Q_marker_token, self.Q_marker_token_id = '[Q]', self.tok.convert_tokens_to_ids('[unused0]')
+        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
+        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
+        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
+
+        assert self.Q_marker_token_id == 100 and self.mask_token_id == 103
+
+    def tokenize(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
+
+        if not add_special_tokens:
+            return tokens
+
+        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
+        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst)+3)) for lst in tokens]
+
+        return tokens
+
+    def encode(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+        if not add_special_tokens:
+            return ids
+
+        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
+        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst)+3)) for lst in ids]
+
+        return ids
+
+    def tensorize(self, batch_text, bsize=None):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        # add placeholder for the [Q] marker
+        batch_text = ['. ' + x for x in batch_text]
+
+        obj = self.tok(batch_text, padding='max_length', truncation=True,
+                       return_tensors='pt', max_length=self.query_maxlen)
+
+        ids, mask = obj['input_ids'], obj['attention_mask']
+
+        # postprocess for the [Q] marker and the [MASK] augmentation
+        ids[:, 1] = self.Q_marker_token_id
+        ids[ids == 0] = self.mask_token_id
+
+        if bsize:
+            batches = _split_into_batches(ids, mask, bsize)
+            return batches
+
+        return ids, mask
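A minimal sketch of `QueryTokenizer.tensorize` (the 32-token maxlen is illustrative), highlighting ColBERT's query augmentation:

from colbert.modeling.tokenization import QueryTokenizer

query_tok = QueryTokenizer(query_maxlen=32)   # illustrative maxlen
ids, mask = query_tok.tensorize(['what is colbert?'])

# ids is padded to exactly query_maxlen; ids[:, 1] becomes the [Q] marker (id 100)
# and every remaining pad id 0 is rewritten to [MASK] (id 103) -- query augmentation.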
colbert/modeling/tokenization/utils.py
ADDED
@@ -0,0 +1,51 @@
+import torch
+
+
+def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
+    assert len(queries) == len(positives) == len(negatives)
+    assert bsize is None or len(queries) % bsize == 0
+
+    N = len(queries)
+    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
+    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
+    D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)
+
+    # Compute max among {length of i^th positive, length of i^th negative} for i \in N
+    maxlens = D_mask.sum(-1).max(0).values
+
+    # Sort by maxlens
+    indices = maxlens.sort().indices
+    Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
+    D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]
+
+    (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask
+
+    query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
+    positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
+    negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)
+
+    batches = []
+    for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches):
+        Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
+        D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
+        batches.append((Q, D))
+
+    return batches
+
+
+def _sort_by_length(ids, mask, bsize):
+    if ids.size(0) <= bsize:
+        return ids, mask, torch.arange(ids.size(0))
+
+    indices = mask.sum(-1).sort().indices
+    reverse_indices = indices.sort().indices
+
+    return ids[indices], mask[indices], reverse_indices
+
+
+def _split_into_batches(ids, mask, bsize):
+    batches = []
+    for offset in range(0, ids.size(0), bsize):
+        batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))
+
+    return batches
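A minimal sketch of `tensorize_triples` on toy triples (the strings and maxlens are illustrative); note that each batch pairs a duplicated query tensor with the positives stacked on top of the negatives:

from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples

queries = ['what is faiss?', 'what is colbert?']          # toy triples (assumed data)
positives = ['faiss is a similarity-search library.', 'colbert is a retriever.']
negatives = ['the sky is blue.', 'cats sleep a lot.']

batches = tensorize_triples(QueryTokenizer(32), DocTokenizer(180),
                            queries, positives, negatives, bsize=2)
(Q_ids, Q_mask), (D_ids, D_mask) = batches[0]   # queries duplicated; D = positives then negatives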