欧卫
'add_app_files'
58627fa
raw
history blame
No virus
2.72 kB
from colbert.infra.run import Run
import os
import ujson
from colbert.utils.utils import print_message
from colbert.infra.provenance import Provenance
from utility.utils.save_metadata import get_metadata_only
class Examples:
    """A list of training examples, loaded from a JSON-lines file or given directly.

    Each example is whatever one line of the file decodes to (sliced to the
    first ``nway + 1`` items when ``nway`` is set), or one element of the
    ``data`` list passed in.
    """

    def __init__(self, path=None, data=None, nway=None, provenance=None):
        """Build the collection from ``data`` if given, otherwise from ``path``.

        Args:
            path: File with one JSON document per line; read when ``data`` is
                not supplied.
            data: Already-loaded examples. Takes precedence over ``path``
                when non-empty.
            nway: If truthy, each example loaded from ``path`` is truncated
                to its first ``nway + 1`` items.
            provenance: Provenance record; falls back to ``path`` and then to
                a fresh ``Provenance()``.
        """
        self.__provenance = provenance or path or Provenance()
        self.nway = nway
        self.path = path

        # An explicitly supplied empty list with no path to fall back on is a
        # valid (empty) dataset; previously this tried to open(None).
        if data or (data is not None and path is None):
            self.data = data
        else:
            self.data = self._load_file(path)

    def provenance(self):
        """Return the provenance recorded at construction time."""
        return self.__provenance

    def toDict(self):
        """Return the provenance (same value as ``provenance()``)."""
        return self.provenance()

    def _load_file(self, path):
        """Read ``path`` line by line and return the decoded examples.

        When ``self.nway`` is truthy, each decoded line is cut to its first
        ``nway + 1`` items; otherwise the slice bound is ``self.nway`` itself
        (``None`` keeps the whole line).
        """
        # NOTE(review): the extra item is presumably the query preceding the
        # nway passages — confirm against the data format.
        limit = self.nway + 1 if self.nway else self.nway

        with open(path) as f:
            return [ujson.loads(line)[:limit] for line in f]

    def tolist(self, rank=None, nranks=None):
        """Return the examples as a list, optionally sharded for one rank.

        With ``rank``/``nranks`` given, returns every ``nranks``-th example
        starting at index ``rank`` (i.e. those with ``idx % nranks == rank``);
        otherwise returns a shallow copy of all examples.

        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.

        Raises:
            AssertionError: if ``rank`` is not in ``range(nranks)``.
        """
        if rank or nranks:
            assert rank in range(nranks), (rank, nranks)
            # Start the stride at `rank` so each rank gets a distinct shard;
            # starting at 0 handed every rank the same examples.
            return [self.data[idx] for idx in range(rank, len(self.data), nranks)]

        return list(self.data)

    def save(self, new_path):
        """Write the examples as JSON lines plus a ``.meta`` sidecar file.

        Both files are opened through ``Run().open``. The sidecar
        (``f'{new_path}.meta'``) holds the result of ``get_metadata_only()``
        and this object's provenance.

        Returns:
            The ``name`` attribute of the file object the examples were
            written to.

        Raises:
            AssertionError: if the last path component has no ``json``
                segment between dots.
        """
        assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}")

        with Run().open(new_path, 'w') as f:
            for example in self.data:
                ujson.dump(example, f)
                f.write('\n')

            # Captured here because `f` is rebound by the second `with` below.
            output_path = f.name
            print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}")

        with Run().open(f'{new_path}.meta', 'w') as f:
            d = {}
            d['metadata'] = get_metadata_only()
            d['provenance'] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path

    @classmethod
    def cast(cls, obj, nway=None):
        """Coerce ``obj`` into an instance of this class.

        A ``str`` is treated as a path, a ``list`` as in-memory data, and an
        existing instance is returned unchanged (``nway`` must then be None).

        Raises:
            AssertionError: if ``obj`` is an instance and ``nway`` is given,
                or if ``obj`` is of an unsupported type.
        """
        if isinstance(obj, str):
            return cls(path=obj, nway=nway)

        if isinstance(obj, list):
            return cls(data=obj, nway=nway)

        if isinstance(obj, cls):
            assert nway is None, nway
            return obj

        # Raised explicitly (rather than `assert False`) so the failure is
        # not stripped when Python runs with -O.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")