# colbert/search/strided_tensor.py
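"""
Lookups over a packed tensor of variable-length rows.

Rows are concatenated back-to-back in one tensor; a set of precomputed strided
views (one per candidate stride) lets GPU lookups gather padded slices in a
single indexing op, while a small JIT-compiled C++ extension handles segmented
lookups on CPU.
"""
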
import os
import pathlib

import torch
from torch.utils.cpp_extension import load

from colbert.utils.utils import flatten, print_message

from .strided_tensor_core import StridedTensorCore, _create_mask, _create_view


class StridedTensor(StridedTensorCore):
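    """
    A packed tensor of variable-length rows, addressable by row ID. `lookup`
    returns rows either packed back-to-back alongside their lengths, or padded
    to a common stride alongside a boolean mask.
    """
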
def __init__(self, packed_tensor, lengths, dim=None, use_gpu=True):
super().__init__(packed_tensor, lengths, dim=dim, use_gpu=use_gpu)
StridedTensor.try_load_torch_extensions(use_gpu)
@classmethod
def try_load_torch_extensions(cls, use_gpu):
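        """
        JIT-compile and cache the C++ extension used for CPU-side segmented
        lookups. Skipped entirely when running on GPU, where the precomputed
        strided views make the extension unnecessary.
        """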
if hasattr(cls, "loaded_extensions") or use_gpu:
return
print_message(f"Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
segmented_lookup_cpp = load(
name="segmented_lookup_cpp",
sources=[
os.path.join(
pathlib.Path(__file__).parent.resolve(), "segmented_lookup.cpp"
),
],
extra_cflags=["-O3"],
verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
)
cls.segmented_lookup = segmented_lookup_cpp.segmented_lookup_cpp
cls.loaded_extensions = True
@classmethod
def pad_packed(cls, packed_tensor, lengths):
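        """
        Pad a packed tensor into a (num_rows, max_len, *inner_dims) view plus
        a boolean mask marking the valid positions, without relying on the
        instance's precomputed views.
        """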
        packed_tensor, lengths = packed_tensor.cuda().contiguous(), lengths.cuda()

        inner_dims = packed_tensor.size()[1:]
        stride = lengths.max().item()

        # Row i begins at sum(lengths[:i]), i.e., the exclusive prefix sum.
        # The original subtracted lengths[0] from the cumulative sum, which is
        # only correct when every row shares one length; that was the bug the
        # old `assert False` here flagged.
        offsets = torch.cumsum(lengths, dim=0) - lengths
padding = torch.zeros(stride, *inner_dims, device=packed_tensor.device, dtype=packed_tensor.dtype)
packed_tensor = torch.cat((packed_tensor, padding))
view = _create_view(packed_tensor, stride, inner_dims)[offsets]
mask = _create_mask(lengths, stride, like=view)
return view, mask
def _prepare_lookup(self, pids):
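        """Normalize `pids` to a 1-D long tensor and gather its lengths and offsets."""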
if isinstance(pids, list):
pids = torch.tensor(pids)
assert pids.dim() == 1
if self.use_gpu:
pids = pids.cuda()
pids = pids.long()
lengths = self.lengths[pids]
if self.use_gpu:
lengths = lengths.cuda()
offsets = self.offsets[pids]
return pids, lengths, offsets
def lookup(self, pids, output='packed'):
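        """
        Gather the rows for `pids`. On GPU, slice the smallest precomputed view
        whose stride fits the longest requested row; on CPU, defer to the
        segmented_lookup extension. Returns (tensor, mask) for output='padded',
        else (packed_tensor, lengths).
        """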
pids, lengths, offsets = self._prepare_lookup(pids)
if self.use_gpu:
stride = lengths.max().item()
stride = next(s for s in self.strides if stride <= s)
tensor = self.views[stride][offsets]
            tensor = tensor.cuda()
mask = _create_mask(lengths, stride, use_gpu=self.use_gpu)
if output == 'padded':
return tensor, mask
assert output == 'packed'
tensor = tensor[mask]
else:
tensor = StridedTensor.segmented_lookup(self.tensor, pids, lengths, offsets)
return tensor, lengths
def lookup_staggered(self, pids, output='packed'):
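        """
        Like `lookup`, but gathers rows bucketed by stride so each bucket reads
        from the narrowest view that fits it, then copies the buckets into one
        max_stride-wide buffer and permutes rows back to the requested order.
        """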
permute_idxs, unordered_tensors, unordered_lengths, unordered_masks = self.lookup_packed_unordered(pids)
output_tensor = torch.empty(permute_idxs.size(0), self.max_stride, *self.inner_dims,
dtype=unordered_tensors[0].dtype, device=unordered_tensors[0].device)
output_mask = torch.zeros(permute_idxs.size(0), self.max_stride,
dtype=unordered_masks[0].dtype, device=unordered_masks[0].device)
offset = 0
for tensor, mask in zip(unordered_tensors, unordered_masks):
endpos = offset + tensor.size(0)
output_tensor[offset:endpos, :tensor.size(1)] = tensor
output_mask[offset:endpos, :mask.size(1)] = mask
offset = endpos
output_mask = output_mask[permute_idxs]
output_tensor = output_tensor[permute_idxs]
if output == 'padded':
return output_tensor, output_mask
assert output == 'packed'
output_tensor = output_tensor[output_mask]
return output_tensor, unordered_lengths[permute_idxs]
def lookup_packed_unordered(self, pids):
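        """
        Bucket the requested pids by the smallest stride that fits each row and
        gather each bucket from its view. Returns the permutation that restores
        the original order, plus the per-bucket tensors, lengths, and masks.
        """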
pids, lengths, offsets = self._prepare_lookup(pids)
lengths2 = lengths.clone()
sentinel = self.strides[-1] + 1
order = torch.arange(pids.size(0), device='cuda' if self.use_gpu else 'cpu')
all_orders = []
all_tensors = []
all_lengths = []
all_masks = []
for stride in self.strides:
is_shorter = lengths2 <= stride
if is_shorter.sum() == 0:
continue
order_ = order[is_shorter]
tensor_, lengths_, mask_ = self._lookup_with_stride(stride, lengths[is_shorter], offsets[is_shorter])
all_orders.append(order_)
all_tensors.append(tensor_)
all_lengths.append(lengths_)
all_masks.append(mask_)
lengths2[is_shorter] = sentinel
        assert (lengths2 == sentinel).all(), "every pid must fall into some stride bucket"
all_orders = torch.cat(all_orders)
permute_idxs = torch.sort(all_orders).indices
return permute_idxs, all_tensors, torch.cat(all_lengths), all_masks
def _lookup_with_stride(self, stride, lengths, offsets):
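        """Slice one strided view at `offsets` and build the matching validity mask."""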
tensor = self.views[stride][offsets]
if self.use_gpu:
tensor = tensor.cuda()
mask = _create_mask(lengths, stride, use_gpu=self.use_gpu)
# tensor = tensor[mask]
return tensor, lengths, mask
if __name__ == '__main__':
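    # Ad-hoc smoke test: load an inverted-file index dump, time batched
    # lookups, and compare against a plain Python flattening of the same lists.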
# lst = []
# for _ in range(10):
# lst.append(list(range(random.randint(0, 10))))
# print(lst)
# t = StridedTensor.from_nested_list(lst)
# print(t.lookup([9]))
    import pickle
index_path = '/future/u/okhattab/root/unit/indexes/2021/08/residual.NQ-micro'
with open(os.path.join(index_path, "centroid_idx_to_embedding_ids.pickle"), "rb") as f:
ivf_list = pickle.load(f)
assert len(ivf_list) == max(ivf_list.keys()) + 1
ivf_list = [ivf_list[i] for i in range(len(ivf_list))]
for x in ivf_list:
assert type(x) is list
assert type(x[0]) is int
ncentroids = len(ivf_list)
ivf = StridedTensor.from_nested_list(ivf_list)
import time
torch.cuda.synchronize()
t = time.time()
N = 100
for _ in range(N):
probed_centroids = torch.randint(0, ncentroids, size=(32, 8)).flatten()
        emb_ids, emb_ids_lengths = ivf.lookup(probed_centroids)  # returns (packed rows, lengths)
torch.cuda.synchronize()
print((time.time() - t) * 1000 / N, 'ms')
print(emb_ids_lengths)
slow_result = flatten([ivf_list[idx] for idx in probed_centroids.flatten().tolist()])
print(emb_ids.size(), len(slow_result))
for a, b in zip(slow_result, emb_ids.flatten().tolist()):
assert a == b, (a, b)
print("#> Done!")
    print(ivf.lookup(probed_centroids, output='padded')[0].size())
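
    # A minimal, self-contained sketch of the same round-trip, with no index
    # files required; the toy lists below are invented for illustration:
    #
    #   lst = [[1, 2, 3], [4], [5, 6]]
    #   st = StridedTensor.from_nested_list(lst)
    #   packed, lengths = st.lookup([2, 0])  # rows 2 and 0, packed back-to-back
    #   assert packed.tolist() == [5, 6, 1, 2, 3]
    #   assert lengths.tolist() == [2, 3]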