# Could be .tsv or .json. The latter always allows more customization via optional parameters.
# I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiB.
# We just need a data structure that shares things across processes without too much pickling.
# I think multiprocessing.Manager can do that!
import os
import itertools
from colbert.evaluation.loaders import load_collection
from colbert.infra.run import Run
class Collection:
    """A collection of passages, held in memory and optionally tied to a file.

    A collection is built either from an in-memory sequence of passages
    (``data``) or from a file on disk (``path``). Only ``.tsv`` files can be
    loaded today; any other extension reaches the unimplemented ``.jsonl``
    loader and raises ``NotImplementedError``.
    """

    def __init__(self, path=None, data=None):
        """Wrap ``data`` if given, otherwise load the collection from ``path``.

        Args:
            path: Location of the collection file; also recorded as provenance.
            data: Optional in-memory passages. An empty sequence is honored as
                an empty collection rather than triggering a file load.

        Raises:
            ValueError: If neither ``data`` nor ``path`` is provided.
        """
        self.path = path

        # Compare against None (not truthiness) so that ``data=[]`` is a valid,
        # empty collection instead of falling through to ``_load_file(None)``.
        if data is not None:
            self.data = data
        elif path is not None:
            self.data = self._load_file(path)
        else:
            raise ValueError("Collection requires either `path` or `data`.")

    def __iter__(self):
        # TODO: If __data isn't there, stream from disk!
        return iter(self.data)

    def __getitem__(self, item):
        # TODO: Load from disk the first time this is called. Unless self.data is already not None.
        return self.data[item]

    def __len__(self):
        # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
        return len(self.data)

    def _load_file(self, path):
        """Load passages from ``path``, dispatching on its file extension."""
        self.path = path
        return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path)

    def _load_tsv(self, path):
        # Delegates to colbert.evaluation.loaders.load_collection; the return
        # type is defined there (presumably a list of passage strings — verify).
        return load_collection(path)

    def _load_jsonl(self, path):
        raise NotImplementedError()

    def provenance(self):
        """Return the path this collection came from (``None`` if built from data only)."""
        return self.path

    def toDict(self):
        """Return a serializable summary of this collection (its provenance)."""
        return {'provenance': self.provenance()}

    def save(self, new_path):
        """Write the collection to ``new_path``, one ``<pid><TAB><content>`` line per passage.

        Passage ids are the 0-based positions in ``self.data``.

        Args:
            new_path: Destination path; must end in ``.tsv`` and must not exist yet.

        Returns:
            ``f.name`` of the file opened through ``Run().open`` (presumably the
            written path — confirm against ``Run.open``).
        """
        assert new_path.endswith('.tsv'), "TODO: Support .json[l] too."
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            # TODO: expects content to always be a string here; no separate title!
            for pid, content in enumerate(self.data):
                f.write(f'{pid}\t{content}\n')

            return f.name

    def enumerate(self, rank, nranks=None):
        """Yield ``(pid, passage)`` for every passage in the chunks owned by ``rank``.

        Args:
            rank: Index of the worker whose chunks should be produced.
            nranks: Total number of workers; defaults to ``Run().nranks``.
        """
        for _, offset, passages in self.enumerate_batches(rank=rank, nranks=nranks):
            for idx, passage in enumerate(passages):
                yield (offset + idx, passage)

    def enumerate_batches(self, rank, chunksize=None, nranks=None):
        """Yield ``(chunk_idx, offset, passages)`` for the chunks owned by ``rank``.

        The collection is cut into consecutive chunks of ``chunksize`` passages,
        dealt round-robin to ranks (chunk ``k`` belongs to rank ``k % nranks``).
        ``offset`` is the global index of the chunk's first passage.

        Args:
            rank: Index of the worker whose chunks should be produced.
            chunksize: Passages per chunk; defaults to ``get_chunksize()``.
            nranks: Total number of workers; defaults to ``Run().nranks``.
        """
        assert rank is not None, "TODO: Add support for the rank=None case."

        if nranks is None:
            nranks = Run().nranks
        chunksize = chunksize or self.get_chunksize(nranks=nranks)

        offset = 0
        iterator = iter(self)

        for chunk_idx, owner in enumerate(itertools.cycle(range(nranks))):
            batch = list(itertools.islice(iterator, chunksize))

            if batch and owner == rank:
                yield (chunk_idx, offset, batch)

            offset += len(batch)

            # A short (or empty) chunk means the passages are exhausted; this is
            # the only exit from the otherwise infinite cycle() above.
            if len(batch) < chunksize:
                return

    def get_chunksize(self, nranks=None):
        """Return the default chunk size: an even split across ranks, capped at 25k.

        Args:
            nranks: Total number of workers; defaults to ``Run().nranks``.
        """
        if nranks is None:
            nranks = Run().nranks
        return min(25_000, 1 + len(self) // nranks)  # 25k is great, 10k allows things to reside on GPU??

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into a collection.

        A ``str`` is treated as a path, a ``list`` as in-memory passages, and an
        existing instance of ``cls`` is returned unchanged.

        Raises:
            TypeError: If ``obj`` is none of the supported types.
        """
        if isinstance(obj, str):
            return cls(path=obj)

        if isinstance(obj, list):
            return cls(data=obj)

        if isinstance(obj, cls):
            return obj

        # An explicit raise (not ``assert False``) so the check survives ``python -O``.
        raise TypeError(f"obj has type {type(obj)} which is not compatible with cast()")
# TODO: Look up path in some global [per-thread or thread-safe] list.