# StridedTensor: batched lookup over a packed (ragged) tensor representation.
from struct import pack | |
import torch | |
from torch._C import device | |
from colbert.utils.utils import flatten, print_message | |
from .strided_tensor_core import StridedTensorCore, _create_mask, _create_view | |
import os | |
import pathlib | |
from torch.utils.cpp_extension import load | |
class StridedTensor(StridedTensorCore):
    """A packed "ragged" tensor: rows of varying length stored back-to-back.

    Row boundaries come from ``lengths``/``offsets`` (built by
    ``StridedTensorCore``), and ``self.views`` provides pre-strided padded
    views of the flat tensor for fast batched lookup.
    """

    def __init__(self, packed_tensor, lengths, dim=None, use_gpu=True):
        super().__init__(packed_tensor, lengths, dim=dim, use_gpu=use_gpu)

        # CPU-only lookups go through a JIT-compiled C++ extension; load it
        # lazily, once per class.
        StridedTensor.try_load_torch_extensions(use_gpu)

    # FIX: this method is invoked as StridedTensor.try_load_torch_extensions(use_gpu)
    # (see __init__). Without @classmethod, `use_gpu` would bind to the `cls`
    # parameter and the call would raise TypeError (missing argument).
    @classmethod
    def try_load_torch_extensions(cls, use_gpu):
        """Compile and cache the segmented_lookup C++ extension (CPU path only)."""
        # Skip if already loaded, or if the GPU path (which doesn't need it) is in use.
        if hasattr(cls, "loaded_extensions") or use_gpu:
            return

        print_message(f"Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        segmented_lookup_cpp = load(
            name="segmented_lookup_cpp",
            sources=[
                os.path.join(
                    pathlib.Path(__file__).parent.resolve(), "segmented_lookup.cpp"
                ),
            ],
            extra_cflags=["-O3"],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )
        cls.segmented_lookup = segmented_lookup_cpp.segmented_lookup_cpp

        cls.loaded_extensions = True

    # FIX: declared with `cls` as its first parameter, so it must be a
    # @classmethod to be callable as written. The body is intentionally
    # disabled by the assert below (known-broken, kept for reference).
    @classmethod
    def pad_packed(cls, packed_tensor, lengths):
        assert False, "This seems to be incorrect but I can't see why. Is it the inner_dims in the views?"

        packed_tensor, lengths = packed_tensor.cuda().contiguous(), lengths.cuda()

        inner_dims = packed_tensor.size()[1:]
        stride = lengths.max().item()
        # NOTE(review): an exclusive cumsum is `cumsum(lengths) - lengths`;
        # subtracting only lengths[0] looks wrong and may be the bug the
        # assert above refers to — confirm before ever enabling this path.
        offsets = torch.cumsum(lengths, dim=0) - lengths[0]

        padding = torch.zeros(stride, *inner_dims, device=packed_tensor.device, dtype=packed_tensor.dtype)
        packed_tensor = torch.cat((packed_tensor, padding))

        view = _create_view(packed_tensor, stride, inner_dims)[offsets]
        mask = _create_mask(lengths, stride, like=view)

        return view, mask

    def _prepare_lookup(self, pids):
        """Normalize `pids` to a 1-D long tensor and gather per-pid lengths/offsets."""
        if isinstance(pids, list):
            pids = torch.tensor(pids)

        assert pids.dim() == 1

        if self.use_gpu:
            pids = pids.cuda()
        pids = pids.long()

        lengths = self.lengths[pids]
        if self.use_gpu:
            lengths = lengths.cuda()
        offsets = self.offsets[pids]

        return pids, lengths, offsets

    def lookup(self, pids, output='packed'):
        """Gather the rows for `pids`.

        On the GPU path, returns (padded_tensor, mask) when output == 'padded';
        otherwise returns (packed_tensor, lengths).
        """
        pids, lengths, offsets = self._prepare_lookup(pids)

        if self.use_gpu:
            # Pick the smallest pre-built strided view that fits the longest row.
            stride = lengths.max().item()
            stride = next(s for s in self.strides if stride <= s)

            tensor = self.views[stride][offsets]
            if self.use_gpu:
                tensor = tensor.cuda()

            mask = _create_mask(lengths, stride, use_gpu=self.use_gpu)

            if output == 'padded':
                return tensor, mask

            assert output == 'packed'
            tensor = tensor[mask]
        else:
            # CPU path: the compiled extension gathers the variable-length
            # segments directly from the flat tensor.
            tensor = StridedTensor.segmented_lookup(self.tensor, pids, lengths, offsets)

        return tensor, lengths

    def lookup_staggered(self, pids, output='packed'):
        """Like lookup(), but gathers per-stride groups and reassembles them
        into a single padded (or packed) tensor in the original pid order."""
        permute_idxs, unordered_tensors, unordered_lengths, unordered_masks = self.lookup_packed_unordered(pids)

        output_tensor = torch.empty(permute_idxs.size(0), self.max_stride, *self.inner_dims,
                                    dtype=unordered_tensors[0].dtype, device=unordered_tensors[0].device)
        output_mask = torch.zeros(permute_idxs.size(0), self.max_stride,
                                  dtype=unordered_masks[0].dtype, device=unordered_masks[0].device)

        # Copy each per-stride group into a left-aligned slice of the output.
        offset = 0
        for tensor, mask in zip(unordered_tensors, unordered_masks):
            endpos = offset + tensor.size(0)
            output_tensor[offset:endpos, :tensor.size(1)] = tensor
            output_mask[offset:endpos, :mask.size(1)] = mask
            offset = endpos

        # Restore the caller's original pid ordering.
        output_mask = output_mask[permute_idxs]
        output_tensor = output_tensor[permute_idxs]

        if output == 'padded':
            return output_tensor, output_mask

        assert output == 'packed'
        output_tensor = output_tensor[output_mask]

        return output_tensor, unordered_lengths[permute_idxs]

    def lookup_packed_unordered(self, pids):
        """Gather rows bucketed by stride (shortest strides first).

        Returns (permute_idxs, tensors, lengths, masks) where applying
        permute_idxs to the concatenated groups restores the input order.
        """
        pids, lengths, offsets = self._prepare_lookup(pids)

        lengths2 = lengths.clone()
        # Sentinel larger than any stride: marks rows already consumed.
        sentinel = self.strides[-1] + 1
        order = torch.arange(pids.size(0), device='cuda' if self.use_gpu else 'cpu')

        all_orders = []
        all_tensors = []
        all_lengths = []
        all_masks = []

        for stride in self.strides:
            is_shorter = lengths2 <= stride
            if is_shorter.sum() == 0:
                continue

            order_ = order[is_shorter]
            tensor_, lengths_, mask_ = self._lookup_with_stride(stride, lengths[is_shorter], offsets[is_shorter])

            all_orders.append(order_)
            all_tensors.append(tensor_)
            all_lengths.append(lengths_)
            all_masks.append(mask_)

            # Mark these rows consumed so they don't match a larger stride too.
            lengths2[is_shorter] = sentinel

        # Every row must have matched some stride.
        assert lengths2.allclose(torch.tensor([sentinel], device='cuda' if self.use_gpu else 'cpu'))

        all_orders = torch.cat(all_orders)
        permute_idxs = torch.sort(all_orders).indices

        return permute_idxs, all_tensors, torch.cat(all_lengths), all_masks

    def _lookup_with_stride(self, stride, lengths, offsets):
        """Gather rows via the padded view for a single stride; returns
        (padded_tensor, lengths, mask) — the tensor is NOT packed here."""
        tensor = self.views[stride][offsets]
        if self.use_gpu:
            tensor = tensor.cuda()

        mask = _create_mask(lengths, stride, use_gpu=self.use_gpu)
        # tensor = tensor[mask]

        return tensor, lengths, mask
if __name__ == '__main__':
    # Ad-hoc smoke test + micro-benchmark: build a StridedTensor over an IVF
    # index (centroid id -> embedding-id posting list) and compare its batched
    # lookup against a pure-Python flatten of the same lists.
    # NOTE(review): depends on a hard-coded local path and a CUDA device;
    # not portable outside the original author's machine.
    # lst = []
    # for _ in range(10):
    #     lst.append(list(range(random.randint(0, 10))))
    # print(lst)
    # t = StridedTensor.from_nested_list(lst)
    # print(t.lookup([9]))
    import os
    import pickle
    index_path = '/future/u/okhattab/root/unit/indexes/2021/08/residual.NQ-micro'
    # Pickled dict mapping {centroid_idx: [embedding ids]}.
    with open(os.path.join(index_path, "centroid_idx_to_embedding_ids.pickle"), "rb") as f:
        ivf_list = pickle.load(f)
    # Keys must be exactly 0..ncentroids-1 so the dict can be densified to a list.
    assert len(ivf_list) == max(ivf_list.keys()) + 1
    ivf_list = [ivf_list[i] for i in range(len(ivf_list))]
    for x in ivf_list:
        assert type(x) is list
        # NOTE(review): x[0] assumes no centroid has an empty posting list — confirm.
        assert type(x[0]) is int
    ncentroids = len(ivf_list)
    # from_nested_list is defined on StridedTensorCore (not visible in this file).
    ivf = StridedTensor.from_nested_list(ivf_list)
    import time
    torch.cuda.synchronize()
    t = time.time()
    N = 100
    for _ in range(N):
        probed_centroids = torch.randint(0, ncentroids, size=(32, 8)).flatten()
        # NOTE(review): lookup() in this file returns a plain (tensor, lengths)
        # tuple, which has no .as_packed_tensor(); this call looks stale —
        # verify against the actual return type of from_nested_list's class.
        emb_ids, emb_ids_lengths = ivf.lookup(probed_centroids).as_packed_tensor()
    torch.cuda.synchronize()
    # Average wall-clock per lookup over N iterations, in milliseconds.
    print((time.time() - t) * 1000 / N, 'ms')
    print(emb_ids_lengths)
    # Reference result: flatten the same posting lists in pure Python.
    slow_result = flatten([ivf_list[idx] for idx in probed_centroids.flatten().tolist()])
    print(emb_ids.size(), len(slow_result))
    # Element-wise check that the packed lookup matches the reference.
    for a, b in zip(slow_result, emb_ids.flatten().tolist()):
        assert a == b, (a, b)
    print("#> Done!")
    print(ivf.lookup(probed_centroids).as_padded_tensor()[0].size())