Spaces:
Runtime error
Runtime error
import torch | |
from colbert.utils.utils import flatten, print_message | |
from colbert.indexing.loaders import load_doclens | |
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided | |
from colbert.search.strided_tensor import StridedTensor | |
from colbert.search.candidate_generation import CandidateGeneration | |
from .index_loader import IndexLoader | |
from colbert.modeling.colbert import colbert_score, colbert_score_packed, colbert_score_reduce | |
from math import ceil | |
import os | |
import pathlib | |
from torch.utils.cpp_extension import load | |
class IndexScorer(IndexLoader, CandidateGeneration): | |
def __init__(self, index_path, use_gpu=True):
    """
    Initialise the scorer for the index stored at `index_path`.

    The parent initialiser runs first with the same `index_path` and
    `use_gpu` values. `self.codec`, `self.embeddings` and `self.doclens`
    are read right after it, so they are presumably populated by the
    parent classes (confirm in IndexLoader).
    """
    super().__init__(use_gpu=use_gpu, index_path=index_path)

    # Class-level loading of the C++ helpers; it returns immediately when
    # `use_gpu` is truthy or when the extensions were already loaded.
    IndexScorer.try_load_torch_extensions(use_gpu)

    strided_embeddings = ResidualEmbeddingsStrided(self.codec, self.embeddings, self.doclens)
    self.embeddings_strided = strided_embeddings
@classmethod
def try_load_torch_extensions(cls, use_gpu):
    """
    Build and attach the C++ helper extensions used when running without a GPU.

    Does nothing when `use_gpu` is truthy or when a previous call has
    already set `cls.loaded_extensions`. Otherwise the `filter_pids_cpp`
    and `decompress_residuals_cpp` extensions are compiled from the `.cpp`
    files that sit next to this module and exposed as `cls.filter_pids`
    and `cls.decompress_residuals`.

    Set the environment variable COLBERT_LOAD_TORCH_EXTENSION_VERBOSE to
    the exact string "True" to get verbose build output.

    This must be a classmethod: it is invoked on the class itself
    (`IndexScorer.try_load_torch_extensions(use_gpu)`) and stores its
    results as class attributes.
    """
    if hasattr(cls, "loaded_extensions") or use_gpu:
        return

    # Both values are identical for every extension, so compute them once.
    verbose = os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True"
    source_dir = pathlib.Path(__file__).parent.resolve()

    # (class attribute to set, extension name). The extension name is also
    # the stem of the source file and the name of the exported function.
    extensions = (
        ("filter_pids", "filter_pids_cpp"),
        ("decompress_residuals", "decompress_residuals_cpp"),
    )

    for attr_name, ext_name in extensions:
        print_message(
            f"Loading {ext_name} extension "
            "(set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
        )
        module = load(
            name=ext_name,
            sources=[os.path.join(source_dir, f"{ext_name}.cpp")],
            extra_cflags=["-O3"],
            verbose=verbose,
        )
        setattr(cls, attr_name, getattr(module, ext_name))

    # Only mark as loaded once every extension has been built successfully.
    cls.loaded_extensions = True
def lookup_eids(self, embedding_ids, codes=None, out_device='cuda'):
    """
    Forward an embedding-id lookup to `self.embeddings_strided`.

    `codes` and `out_device` are passed through by keyword, and the
    delegate's result is returned unchanged.
    """
    strided = self.embeddings_strided
    return strided.lookup_eids(embedding_ids, out_device=out_device, codes=codes)
def lookup_pids(self, passage_ids, out_device='cuda', return_mask=False):
    """
    Forward a passage-id lookup to `self.embeddings_strided`.

    `out_device` is handed over positionally. `return_mask` is accepted
    for interface compatibility but is not forwarded, so it has no effect
    on the result.
    """
    strided = self.embeddings_strided
    return strided.lookup_pids(passage_ids, out_device)
def retrieve(self, config, Q):
    """
    Run candidate generation for the query tensor `Q`.

    `Q` is truncated along its second axis to `config.query_maxlen`
    entries before being handed to `self.generate_candidates`, whose
    two results are returned as a tuple.
    """
    # NOTE: Candidate generation uses only the query tokens
    query_tokens = Q[:, :config.query_maxlen]

    candidates, centroid_scores = self.generate_candidates(config, query_tokens)
    return candidates, centroid_scores
def embedding_ids_to_pids(self, embedding_ids):
    """
    Return the distinct values of `self.emb2pid` at the given embedding ids.

    `embedding_ids` is cast to int64 and used to index `self.emb2pid`
    (presumably an embedding-id -> passage-id table; confirm against the
    index loader). Duplicates are removed with `torch.unique(sorted=False)`,
    so callers should not rely on the order of the result.

    The looked-up ids are moved to CUDA only when `self.use_gpu` is set,
    which keeps this method usable on CPU-only machines.
    """
    pids = self.emb2pid[embedding_ids.long()]
    if self.use_gpu:
        pids = pids.cuda()
    return torch.unique(pids, sorted=False)
def rank(self, config, Q, filter_fn=None):
    """
    Retrieve, optionally filter, score and sort candidates for `Q`.

    Args:
        config: passed through to `self.retrieve` and `self.score_pids`.
        Q: query tensor, passed through unchanged.
        filter_fn: optional callable applied to the retrieved candidate
            ids; its return value replaces them.

    Returns:
        `(pids, scores)` as two Python lists ordered by descending score.
        Both lists are empty when no candidate remains to be scored.
    """
    with torch.inference_mode():
        pids, centroid_scores = self.retrieve(config, Q)

        if filter_fn is not None:
            pids = filter_fn(pids)

        # Nothing left to score: return early instead of letting the
        # scoring stage fail on an empty batch list.
        if len(pids) == 0:
            return [], []

        scores, pids = self.score_pids(config, Q, pids, centroid_scores)

        ranking = scores.sort(descending=True)
        return pids[ranking.indices].tolist(), ranking.values.tolist()
def score_pids(self, config, Q, pids, centroid_scores):
    """
    Narrow the candidate `pids` and score the survivors against `Q`.

    Always supply a flat list or tensor for `pids`.

    Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
    If Q.size(0) is 1, the matrix will be compared with all passages.
    Otherwise, each query matrix will be compared against the *aligned* passage.

    When `self.use_gpu` is set, candidates are filtered in two passes
    (first with centroid scores pruned by `config.centroid_score_threshold`,
    keeping at most `config.ndocs`; then with the full centroid scores,
    keeping at most `config.ndocs // 4`). Otherwise filtering and
    decompression are delegated to the C++ helpers `IndexScorer.filter_pids`
    and `IndexScorer.decompress_residuals`.

    Returns:
        `(scores, pids)` where `pids` is the filtered candidate set.
    """
    # TODO: Remove batching?
    batch_size = 2 ** 20

    if self.use_gpu:
        centroid_scores = centroid_scores.cuda()

    # Boolean mask over the first axis of `centroid_scores`: True where the
    # maximum over the last axis reaches the configured threshold.
    idx = centroid_scores.max(-1).values >= config.centroid_score_threshold

    if self.use_gpu:
        approx_scores = []

        # Filter docs using pruned centroid scores
        for start in range(0, len(pids), batch_size):
            pids_ = pids[start : start + batch_size]
            codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids_)
            idx_ = idx[codes_packed.long()]

            # Per-passage count of codes that survive pruning.
            pruned_codes_strided = StridedTensor(idx_, codes_lengths, use_gpu=self.use_gpu)
            pruned_codes_padded, pruned_codes_mask = pruned_codes_strided.as_padded_tensor()
            pruned_codes_lengths = (pruned_codes_padded * pruned_codes_mask).sum(dim=1)

            codes_packed_ = codes_packed[idx_]
            approx_scores_ = centroid_scores[codes_packed_.long()]

            if approx_scores_.shape[0] == 0:
                # Every code in this batch was pruned. The placeholder must
                # live on the same device as the other batches' scores,
                # otherwise the torch.cat below fails with a device mismatch.
                approx_scores.append(
                    torch.zeros(
                        (len(pids_),),
                        dtype=approx_scores_.dtype,
                        device=approx_scores_.device,
                    )
                )
                continue

            approx_scores_strided = StridedTensor(approx_scores_, pruned_codes_lengths, use_gpu=self.use_gpu)
            approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
            approx_scores_ = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
            approx_scores.append(approx_scores_)

        approx_scores = torch.cat(approx_scores, dim=0)
        if config.ndocs < len(approx_scores):
            pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

        # Filter docs using full centroid scores
        codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids)
        approx_scores = centroid_scores[codes_packed.long()]
        approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
        approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
        approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
        if config.ndocs // 4 < len(approx_scores):
            pids = pids[torch.topk(approx_scores, k=(config.ndocs // 4)).indices]
    else:
        pids = IndexScorer.filter_pids(
            pids, centroid_scores, self.embeddings.codes, self.doclens,
            self.embeddings_strided.codes_strided.offsets, idx, config.ndocs
        )

    # Rank final list of docs using full approximate embeddings (including residuals)
    if self.use_gpu:
        D_packed, D_mask = self.lookup_pids(pids)
    else:
        D_packed = IndexScorer.decompress_residuals(
            pids,
            self.doclens,
            self.embeddings_strided.codes_strided.offsets,
            self.codec.bucket_weights,
            self.codec.reversed_bit_map,
            self.codec.decompression_lookup_table,
            self.embeddings.residuals,
            self.embeddings.codes,
            self.codec.centroids,
            self.codec.dim,
            self.codec.nbits
        )
        # L2-normalise the decompressed embeddings along the last axis.
        D_packed = torch.nn.functional.normalize(D_packed.to(torch.float32), p=2, dim=-1)
        # On this path `D_mask` holds the per-passage lengths from `self.doclens`.
        D_mask = self.doclens[pids.long()]

    if Q.size(0) == 1:
        return colbert_score_packed(Q, D_packed, D_mask, config), pids

    D_strided = StridedTensor(D_packed, D_mask, use_gpu=self.use_gpu)
    D_padded, D_lengths = D_strided.as_padded_tensor()

    return colbert_score(Q, D_padded, D_lengths, config), pids