# ir_chinese_medqa/colbert/search/index_storage.py

import torch
from colbert.utils.utils import flatten, print_message
from colbert.indexing.loaders import load_doclens
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.search.strided_tensor import StridedTensor
from colbert.search.candidate_generation import CandidateGeneration
from .index_loader import IndexLoader
from colbert.modeling.colbert import colbert_score, colbert_score_packed, colbert_score_reduce
from math import ceil
import os
import pathlib
from torch.utils.cpp_extension import load
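
# IndexScorer combines IndexLoader (which loads the codec, compressed codes/residuals, and
# doclens) with CandidateGeneration (centroid-based candidate probing). Ranking proceeds in
# stages: candidate generation, approximate filtering with centroid scores, and final scoring
# over the full (residual-decompressed) passage embeddings.
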
class IndexScorer(IndexLoader, CandidateGeneration):
    def __init__(self, index_path, use_gpu=True):
        super().__init__(index_path=index_path, use_gpu=use_gpu)

        IndexScorer.try_load_torch_extensions(use_gpu)

        # Strided view over the compressed embeddings, used for per-passage code/embedding lookups.
        self.embeddings_strided = ResidualEmbeddingsStrided(self.codec, self.embeddings, self.doclens)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu):
        # The C++ extensions are only needed for the CPU search path; skip if we are on GPU
        # or if they have already been compiled and loaded.
        if hasattr(cls, "loaded_extensions") or use_gpu:
            return

        print_message(f"Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        filter_pids_cpp = load(
            name="filter_pids_cpp",
            sources=[
                os.path.join(
                    pathlib.Path(__file__).parent.resolve(), "filter_pids.cpp"
                ),
            ],
            extra_cflags=["-O3"],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )
        cls.filter_pids = filter_pids_cpp.filter_pids_cpp

        print_message(f"Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                os.path.join(
                    pathlib.Path(__file__).parent.resolve(), "decompress_residuals.cpp"
                ),
            ],
            extra_cflags=["-O3"],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    def lookup_eids(self, embedding_ids, codes=None, out_device='cuda'):
        return self.embeddings_strided.lookup_eids(embedding_ids, codes=codes, out_device=out_device)

    def lookup_pids(self, passage_ids, out_device='cuda', return_mask=False):
        # NOTE: `return_mask` is currently unused.
        return self.embeddings_strided.lookup_pids(passage_ids, out_device)

    def retrieve(self, config, Q):
        Q = Q[:, :config.query_maxlen]   # NOTE: Candidate generation uses only the query tokens
        embedding_ids, centroid_scores = self.generate_candidates(config, Q)

        return embedding_ids, centroid_scores

    def embedding_ids_to_pids(self, embedding_ids):
        all_pids = torch.unique(self.emb2pid[embedding_ids.long()].cuda(), sorted=False)
        return all_pids

    def rank(self, config, Q, filter_fn=None):
        with torch.inference_mode():
            # Stage 1: candidate generation from the query's centroid interactions.
            pids, centroid_scores = self.retrieve(config, Q)

            if filter_fn is not None:
                pids = filter_fn(pids)

            # Stages 2-3: approximate filtering with centroid scores, then full scoring.
            scores, pids = self.score_pids(config, Q, pids, centroid_scores)

            # Return pids sorted by descending score.
            scores_sorter = scores.sort(descending=True)
            pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

            return pids, scores
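
    # Example use of `filter_fn` (illustrative sketch, not part of this module): restrict
    # ranking to a whitelist of passage ids. `allowed` and the torch.isin-based lambda below
    # are assumptions for illustration.
    #
    #     allowed = torch.tensor([3, 17, 42], device=Q.device)
    #     pids, scores = scorer.rank(config, Q, filter_fn=lambda pids: pids[torch.isin(pids, allowed)])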

    def score_pids(self, config, Q, pids, centroid_scores):
        """
            Always supply a flat list or tensor for `pids`.

            Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
            If Q.size(0) is 1, the matrix will be compared with all passages.
            Otherwise, each query matrix will be compared against the *aligned* passage.
        """

        # TODO: Remove batching?
        batch_size = 2 ** 20

        if self.use_gpu:
            centroid_scores = centroid_scores.cuda()

        # Keep only centroids whose best score against any query token clears the pruning threshold.
        idx = centroid_scores.max(-1).values >= config.centroid_score_threshold

        if self.use_gpu:
            approx_scores = []

            # Filter docs using pruned centroid scores
            for i in range(0, ceil(len(pids) / batch_size)):
                pids_ = pids[i * batch_size : (i + 1) * batch_size]
                codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids_)
                idx_ = idx[codes_packed.long()]
                pruned_codes_strided = StridedTensor(idx_, codes_lengths, use_gpu=self.use_gpu)
                pruned_codes_padded, pruned_codes_mask = pruned_codes_strided.as_padded_tensor()
                pruned_codes_lengths = (pruned_codes_padded * pruned_codes_mask).sum(dim=1)
                codes_packed_ = codes_packed[idx_]
                approx_scores_ = centroid_scores[codes_packed_.long()]
                if approx_scores_.shape[0] == 0:
                    approx_scores.append(torch.zeros((len(pids_),), dtype=approx_scores_.dtype))
                    continue
                approx_scores_strided = StridedTensor(approx_scores_, pruned_codes_lengths, use_gpu=self.use_gpu)
                approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
                approx_scores_ = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
                approx_scores.append(approx_scores_)
            approx_scores = torch.cat(approx_scores, dim=0)

            # Keep the top `config.ndocs` candidates by pruned approximate score.
            if config.ndocs < len(approx_scores):
                pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

            # Filter docs using full centroid scores
            codes_packed, codes_lengths = self.embeddings_strided.lookup_codes(pids)
            approx_scores = centroid_scores[codes_packed.long()]
            approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
            approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
            approx_scores = colbert_score_reduce(approx_scores_padded, approx_scores_mask, config)
            if config.ndocs // 4 < len(approx_scores):
                pids = pids[torch.topk(approx_scores, k=(config.ndocs // 4)).indices]
        else:
            # CPU path: the filter_pids C++ extension performs the centroid-score filtering stages.
            pids = IndexScorer.filter_pids(
                pids, centroid_scores, self.embeddings.codes, self.doclens,
                self.embeddings_strided.codes_strided.offsets, idx, config.ndocs
            )

        # Rank final list of docs using full approximate embeddings (including residuals)
        if self.use_gpu:
            D_packed, D_mask = self.lookup_pids(pids)
        else:
            # CPU path: decompress the residual-compressed embeddings with the C++ extension,
            # then re-normalize them.
            D_packed = IndexScorer.decompress_residuals(
                pids,
                self.doclens,
                self.embeddings_strided.codes_strided.offsets,
                self.codec.bucket_weights,
                self.codec.reversed_bit_map,
                self.codec.decompression_lookup_table,
                self.embeddings.residuals,
                self.embeddings.codes,
                self.codec.centroids,
                self.codec.dim,
                self.codec.nbits
            )
            D_packed = torch.nn.functional.normalize(D_packed.to(torch.float32), p=2, dim=-1)
            D_mask = self.doclens[pids.long()]

        if Q.size(0) == 1:
            return colbert_score_packed(Q, D_packed, D_mask, config), pids

        D_strided = StridedTensor(D_packed, D_mask, use_gpu=self.use_gpu)
        D_padded, D_lengths = D_strided.as_padded_tensor()

        return colbert_score(Q, D_padded, D_lengths, config), pids
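

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions (not guaranteed by this module): a ColBERT index has
    # already been built at `index_path`, and `ColBERTConfig` from the surrounding ColBERT
    # codebase exposes `load_from_index` plus the attributes read above
    # (`query_maxlen`, `centroid_score_threshold`, `ndocs`).
    from colbert.infra import ColBERTConfig  # assumed import path

    index_path = "/path/to/index"                        # placeholder
    config = ColBERTConfig.load_from_index(index_path)   # assumed helper

    scorer = IndexScorer(index_path, use_gpu=torch.cuda.is_available())

    # Stand-in for an encoded query matrix; in practice this comes from the query encoder
    # (e.g., Checkpoint.queryFromText in the ColBERT codebase).
    Q = torch.nn.functional.normalize(
        torch.randn(1, config.query_maxlen, scorer.codec.dim), p=2, dim=-1
    )

    pids, scores = scorer.rank(config, Q)
    print(list(zip(pids[:10], scores[:10])))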