Spaces:
Runtime error
Runtime error
File size: 7,689 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import torch
from colbert.utils.utils import flatten, print_message
from colbert.indexing.loaders import load_doclens
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.search.strided_tensor import StridedTensor
from colbert.search.candidate_generation import CandidateGeneration
from .index_loader import IndexLoader
from colbert.modeling.colbert import colbert_score, colbert_score_packed, colbert_score_reduce
from math import ceil
import os
import pathlib
from torch.utils.cpp_extension import load
class IndexScorer(IndexLoader, CandidateGeneration):
    """Score and rank passages against query embeddings using a ColBERT index.

    Combines index loading (``IndexLoader``) with candidate generation
    (``CandidateGeneration``), then re-ranks candidates through progressively
    more exact stages: pruned centroid scores -> full centroid scores ->
    decompressed (centroid + residual) embeddings.
    """

    def __init__(self, index_path, use_gpu=True):
        """Load the index at ``index_path`` and prepare strided embedding access.

        On CPU (``use_gpu=False``), JIT-compiles the C++ extensions used for
        candidate filtering and residual decompression; on GPU this is a no-op.
        """
        super().__init__(index_path=index_path, use_gpu=use_gpu)

        IndexScorer.try_load_torch_extensions(use_gpu)

        # Strided view over the compressed embeddings, indexed per passage.
        self.embeddings_strided = ResidualEmbeddingsStrided(self.codec, self.embeddings, self.doclens)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu):
        """JIT-compile and cache the CPU-side C++ extensions, once per process.

        Skipped entirely when ``use_gpu`` is True (the pure-torch GPU code
        paths are used instead) or when a previous call already loaded them.
        Exposes ``cls.filter_pids`` and ``cls.decompress_residuals``.
        """
        if hasattr(cls, "loaded_extensions") or use_gpu:
            return

        print_message(f"Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        filter_pids_cpp = load(
            name="filter_pids_cpp",
            sources=[
                os.path.join(
                    pathlib.Path(__file__).parent.resolve(), "filter_pids.cpp"
                ),
            ],
            extra_cflags=["-O3"],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )
        cls.filter_pids = filter_pids_cpp.filter_pids_cpp

        print_message(f"Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                os.path.join(
                    pathlib.Path(__file__).parent.resolve(), "decompress_residuals.cpp"
                ),
            ],
            extra_cflags=["-O3"],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    def lookup_eids(self, embedding_ids, codes=None, out_device='cuda'):
        """Fetch (decompressed) embeddings for the given embedding ids."""
        return self.embeddings_strided.lookup_eids(embedding_ids, codes=codes, out_device=out_device)

    def lookup_pids(self, passage_ids, out_device='cuda', return_mask=False):
        """Fetch packed embeddings (and lengths) for the given passage ids.

        NOTE(review): ``return_mask`` is accepted but currently ignored — the
        call always forwards only ``passage_ids`` and ``out_device``. Kept for
        interface compatibility; confirm whether callers rely on it.
        """
        return self.embeddings_strided.lookup_pids(passage_ids, out_device)

    def retrieve(self, config, Q):
        """Generate candidate passages for query matrix ``Q``.

        Returns ``(embedding_ids, centroid_scores)`` from candidate generation.
        """
        Q = Q[:, :config.query_maxlen]   # NOTE: Candidate generation uses only the query tokens
        embedding_ids, centroid_scores = self.generate_candidates(config, Q)

        return embedding_ids, centroid_scores

    def embedding_ids_to_pids(self, embedding_ids):
        """Map embedding ids to the (unsorted) unique set of passage ids.

        NOTE(review): unconditionally moves to CUDA — appears to be used only
        on the GPU path; confirm it is unreachable when ``use_gpu=False``.
        """
        all_pids = torch.unique(self.emb2pid[embedding_ids.long()].cuda(), sorted=False)
        return all_pids

    def rank(self, config, Q, filter_fn=None):
        """Retrieve, optionally filter, score, and sort passages for ``Q``.

        Returns ``(pids, scores)`` as Python lists, sorted by descending score.
        ``filter_fn``, if given, maps the candidate pid tensor to a filtered one.
        """
        with torch.inference_mode():
            pids, centroid_scores = self.retrieve(config, Q)

            if filter_fn is not None:
                pids = filter_fn(pids)

            scores, pids = self.score_pids(config, Q, pids, centroid_scores)

            scores_sorter = scores.sort(descending=True)
            pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

            return pids, scores

    def score_pids(self, config, Q, pids, centroid_scores):
        """
        Always supply a flat list or tensor for `pids`.
        Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
        If Q.size(0) is 1, the matrix will be compared with all passages.
        Otherwise, each query matrix will be compared against the *aligned* passage.
        """

        # TODO: Remove batching?
        batch_size = 2 ** 20

        if self.use_gpu:
            centroid_scores = centroid_scores.cuda()

        # Centroids whose best score clears the threshold survive pruning.
        idx = centroid_scores.max(-1).values >= config.centroid_score_threshold

        if self.use_gpu:
            approx_scores = []

            # Stage 1: filter docs using pruned centroid scores.
            for i in range(0, ceil(len(pids) / batch_size)):
                pids_ = pids[i * batch_size : (i+1) * batch_size]
                codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids_)
                idx_ = idx[codes_packed.long()]
                pruned_codes_strided = StridedTensor(idx_, codes_lengths, use_gpu=self.use_gpu)
                pruned_codes_padded, pruned_codes_mask = pruned_codes_strided.as_padded_tensor()
                pruned_codes_lengths = (pruned_codes_padded * pruned_codes_mask).sum(dim=1)
                codes_packed_ = codes_packed[idx_]
                approx_scores_ = centroid_scores[codes_packed_.long()]
                if approx_scores_.shape[0] == 0:
                    # BUGFIX: allocate the zero padding on the same device as the
                    # other score chunks. On this (GPU-only) path, centroid_scores
                    # was moved to CUDA above, so a default CPU tensor here would
                    # make the torch.cat below fail on mixed devices.
                    approx_scores.append(
                        torch.zeros((len(pids_),), dtype=approx_scores_.dtype, device=approx_scores_.device)
                    )
                    continue
                approx_scores_strided = StridedTensor(approx_scores_, pruned_codes_lengths, use_gpu=self.use_gpu)
                approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
                approx_scores_ = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
                approx_scores.append(approx_scores_)
            approx_scores = torch.cat(approx_scores, dim=0)
            if config.ndocs < len(approx_scores):
                pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

            # Stage 2: re-filter the survivors using full (unpruned) centroid scores.
            codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids)
            approx_scores = centroid_scores[codes_packed.long()]
            approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
            approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
            approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
            if config.ndocs // 4 < len(approx_scores):
                pids = pids[torch.topk(approx_scores, k=(config.ndocs // 4)).indices]
        else:
            # CPU path: both filtering stages are fused into the C++ extension.
            pids = IndexScorer.filter_pids(
                    pids, centroid_scores, self.embeddings.codes, self.doclens,
                    self.embeddings_strided.codes_strided.offsets, idx, config.ndocs
                )

        # Stage 3: rank the final list of docs using full approximate embeddings
        # (centroids + residuals).
        if self.use_gpu:
            D_packed, D_mask = self.lookup_pids(pids)
        else:
            D_packed = IndexScorer.decompress_residuals(
                    pids,
                    self.doclens,
                    self.embeddings_strided.codes_strided.offsets,
                    self.codec.bucket_weights,
                    self.codec.reversed_bit_map,
                    self.codec.decompression_lookup_table,
                    self.embeddings.residuals,
                    self.embeddings.codes,
                    self.codec.centroids,
                    self.codec.dim,
                    self.codec.nbits
                )
            D_packed = torch.nn.functional.normalize(D_packed.to(torch.float32), p=2, dim=-1)
            D_mask = self.doclens[pids.long()]

        if Q.size(0) == 1:
            return colbert_score_packed(Q, D_packed, D_mask, config), pids

        D_strided = StridedTensor(D_packed, D_mask, use_gpu=self.use_gpu)
        D_padded, D_lengths = D_strided.as_padded_tensor()

        return colbert_score(Q, D_padded, D_lengths, config), pids
|