Spaces:
Runtime error
Runtime error
from colbert.infra.run import Run | |
import os | |
import ujson | |
from colbert.evaluation.loaders import load_queries | |
# TODO: Look up path in some global (per-thread or thread-safe) list.
# TODO: path could be a list of paths, but then how would we tell it apart from a list of queries?
class Queries:
    """Mapping from query id to query text, optionally with full QA records.

    A collection is built either from an in-memory dict (``data``) or from a
    file (``path``). When an entry is itself a dict (a "QA record"), its
    ``'question'`` field becomes the query text and the whole record is kept
    in ``self._qas``.

    Attributes set by the loaders:
        path: The ``path`` argument given to the constructor (may be None).
        data: Dict mapping qid -> query text.
        _qas: Dict mapping qid -> QA record. Only present when at least one
            QA record was loaded; otherwise the attribute does not exist.
    """

    def __init__(self, path=None, data=None):
        """Load queries from ``data`` if it is not None, else from ``path``.

        Args:
            path: File to load when ``data`` is None. Also stored as the
                provenance of this collection.
            data: Optional dict mapping qid -> query text or qid -> QA record
                (a dict with a ``'question'`` key).
        """
        self.path = path

        if data:
            assert isinstance(data, dict), type(data)

        # In-memory data wins; the file is only read when data is None
        # (_load_data returns None in that case and True otherwise).
        if not self._load_data(data):
            self._load_file(path)

    def __len__(self):
        """Return the number of queries."""
        return len(self.data)

    def __iter__(self):
        """Iterate over ``(qid, query_text)`` pairs."""
        return iter(self.data.items())

    def provenance(self):
        """Return the path this collection was constructed with (may be None)."""
        return self.path

    def toDict(self):
        """Return a small dict describing where this collection came from."""
        return {'provenance': self.provenance()}

    def _load_data(self, data):
        """Populate ``self.data`` (and possibly ``self._qas``) from a dict.

        Returns:
            None when ``data`` is None (nothing loaded), True otherwise.
        """
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                # QA record: the query text lives under 'question'.
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        # Plain-text queries only: drop the attribute entirely so that
        # qas()/save_qas() fail loudly instead of returning empty results.
        if not self._qas:
            del self._qas

        return True

    def _load_file(self, path):
        """Populate ``self.data`` from ``path``.

        Paths ending in ``.json`` are read as one JSON object per line, each
        with ``'qid'`` and ``'question'`` keys; the full objects are kept in
        ``self._qas``. Any other path is delegated to ``load_queries``.

        Returns:
            True for non-``.json`` paths, otherwise the loaded ``self.data``.
        """
        if not path.endswith('.json'):
            self.data = load_queries(path)
            return True

        # Load QAs
        self.data = {}
        self._qas = {}

        with open(path) as f:
            for line in f:
                qa = ujson.loads(line)
                qid = qa['qid']

                # Duplicate qids would silently overwrite earlier entries.
                assert qid not in self.data
                self.data[qid] = qa['question']
                self._qas[qid] = qa

        return self.data

    def qas(self):
        """Return a shallow copy of the qid -> QA record mapping.

        Raises:
            AttributeError: if no QA records were loaded (``_qas`` is unset).
        """
        return dict(self._qas)

    def __getitem__(self, key):
        """Return the query text for ``key``."""
        return self.data[key]

    def keys(self):
        """Return a view of the query ids."""
        return self.data.keys()

    def values(self):
        """Return a view of the query texts."""
        return self.data.values()

    def items(self):
        """Return a view of ``(qid, query_text)`` pairs."""
        return self.data.items()

    def save(self, new_path):
        """Write the queries as ``qid<TAB>text`` lines to a new ``.tsv`` file.

        The file is opened through ``Run().open`` rather than the builtin
        ``open``. Refuses to overwrite an existing path.

        Returns:
            The ``name`` attribute of the file object that was written.
        """
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, content in self.data.items():
                f.write(f'{qid}\t{content}\n')

            return f.name

    def save_qas(self, new_path):
        """Write the QA records, one JSON object per line, to a new ``.json`` file.

        Each written object carries its ``'qid'``. Refuses to overwrite an
        existing path.

        Raises:
            AttributeError: if no QA records were loaded (``_qas`` is unset).
        """
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        with open(new_path, 'w') as f:
            for qid, qa in self._qas.items():
                # Serialize a copy so saving does not write 'qid' into the
                # stored (possibly caller-owned) record.
                record = {**qa, 'qid': qid}
                f.write(ujson.dumps(record) + '\n')

    def _load_tsv(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    def _load_jsonl(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into an instance of ``cls``.

        Args:
            obj: A path string, a dict (or list) of query data, or an
                existing instance whose type is exactly ``cls``.

        Returns:
            A new instance built from ``obj``, or ``obj`` itself when it is
            already an instance of ``cls``.

        Raises:
            AssertionError: if ``obj`` has an unsupported type.
        """
        if type(obj) is str:
            return cls(path=obj)

        # NOTE: lists are forwarded as-is; the constructor itself asserts
        # that non-empty data is a dict.
        if isinstance(obj, (dict, list)):
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        # Raised explicitly (not `assert False`) so the check survives -O.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")
# class QuerySet: | |
# def __init__(self, *paths, renumber=False): | |
# self.paths = paths | |
# self.original_queries = [load_queries(path) for path in paths] | |
# if renumber: | |
# self.queries = flatten([q.values() for q in self.original_queries]) | |
# self.queries = {idx: text for idx, text in enumerate(self.queries)} | |
# else: | |
# self.queries = {} | |
# for queries in self.original_queries: | |
# assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \ | |
# "renumber=False requires non-overlapping query IDs" | |
# self.queries.update(queries) | |
# assert len(self.queries) == sum(map(len, self.original_queries)) | |
# def todict(self): | |
# return dict(self.queries) | |
# def tolist(self): | |
# return list(self.queries.values()) | |
# def query_sets(self): | |
# return self.original_queries | |
# def split_rankings(self, rankings): | |
# assert type(rankings) is list | |
# assert len(rankings) == len(self.queries) | |
# sub_rankings = [] | |
# offset = 0 | |
# for source in self.original_queries: | |
# sub_rankings.append(rankings[offset:offset+len(source)]) | |
# offset += len(source) | |
# return sub_rankings | |