欧卫
'add_app_files'
58627fa
raw
history blame
4.39 kB
from colbert.infra.run import Run
import os
import ujson
from colbert.evaluation.loaders import load_queries
# TODO: Look up the path in some global (per-thread or thread-safe) list.
# TODO: `path` could be a list of paths, but then how can we tell it apart from a list of queries?
class Queries:
    """Mapping-like collection of queries keyed by query ID (qid).

    Queries come either from an in-memory dict (``data``) or from a file
    (``path``).  A path ending in ``.json`` is read as one JSON object per
    line, each carrying ``qid`` and ``question`` keys; any other path is
    handed to ``load_queries``.  When full QA records are available they are
    kept next to the question text and exposed through ``qas()``.
    """

    def __init__(self, path=None, data=None):
        """Build the collection from ``data`` if given, otherwise from ``path``.

        Args:
            path: Location of the query file; also recorded as provenance.
            data: Dict mapping qid to either the question text or a record
                dict that has a ``'question'`` key.

        Raises:
            AssertionError: If ``data`` is given but is not a dict.
            ValueError: If neither ``data`` nor ``path`` is given.
        """
        self.path = path

        if data is not None:
            # `data` takes precedence over `path`, even when it is empty.
            assert isinstance(data, dict), type(data)
            self._load_data(data)
        elif path is not None:
            self._load_file(path)
        else:
            raise ValueError("Queries requires either `path` or `data`")

    def __len__(self):
        """Return the number of queries."""
        return len(self.data)

    def __iter__(self):
        """Iterate over ``(qid, question)`` pairs (not over bare qids)."""
        return iter(self.data.items())

    def provenance(self):
        """Return the path this collection was created with (may be None)."""
        return self.path

    def toDict(self):
        """Return a dict holding only the provenance of this collection."""
        return {'provenance': self.provenance()}

    def _load_data(self, data):
        """Populate the collection from an in-memory dict.

        Dict-valued entries are treated as QA records: their ``'question'``
        becomes the query text and the whole record is kept for ``qas()``.

        Returns:
            None (without touching any state) when ``data`` is None,
            True otherwise.
        """
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        # Without any QA record the attribute is removed entirely, so that
        # `qas()` / `save_qas()` raise AttributeError on plain-text queries.
        if not self._qas:
            del self._qas

        return True

    def _load_file(self, path):
        """Populate the collection from ``path``.

        Returns:
            True for non-``.json`` paths (loaded through ``load_queries``),
            or the qid -> question dict for ``.json`` paths.
        """
        if not path.endswith('.json'):
            self.data = load_queries(path)
            return True

        # Load QAs: one JSON object per line.
        self.data = {}
        self._qas = {}

        with open(path) as f:
            for line in f:
                qa = ujson.loads(line)
                qid = qa['qid']

                # Duplicate qids would silently overwrite earlier queries.
                assert qid not in self.data
                self.data[qid] = qa['question']
                self._qas[qid] = qa

        return self.data

    def qas(self):
        """Return a shallow copy of the qid -> QA record mapping.

        Raises:
            AttributeError: If no QA records were loaded.
        """
        return dict(self._qas)

    def __getitem__(self, key):
        """Return the question text stored under qid ``key``."""
        return self.data[key]

    def keys(self):
        """Return a view of the qids."""
        return self.data.keys()

    def values(self):
        """Return a view of the question texts."""
        return self.data.values()

    def items(self):
        """Return a view of the ``(qid, question)`` pairs."""
        return self.data.items()

    def save(self, new_path):
        """Write the queries as ``qid<TAB>question`` lines to a new ``.tsv``.

        The file is opened through ``Run().open``; the ``name`` of the file
        object it hands back is returned.

        Raises:
            AssertionError: If ``new_path`` does not end in ``.tsv`` or
                already exists.
        """
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, question in self.data.items():
                f.write(f'{qid}\t{question}\n')

            return f.name

    def save_qas(self, new_path):
        """Write the QA records, one JSON object per line, to a new ``.json``.

        Every written record carries its qid under the ``'qid'`` key.

        Raises:
            AssertionError: If ``new_path`` does not end in ``.json`` or
                already exists.
            AttributeError: If no QA records were loaded.
        """
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        # Look the records up *before* creating the file, so that a
        # collection without QA records does not leave an empty file behind.
        records = self._qas

        with open(new_path, 'w') as f:
            for qid, qa in records.items():
                # Serialise a merged copy instead of writing 'qid' into the
                # stored record, which may be a dict owned by the caller.
                f.write(ujson.dumps({**qa, 'qid': qid}) + '\n')

    def _load_tsv(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    def _load_jsonl(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into an instance of this class.

        A string is treated as a path, a dict as in-memory data, and an
        existing instance is returned unchanged.

        NOTE(review): lists are forwarded to the constructor as ``data``,
        which only accepts dicts, so that branch currently always fails.

        Raises:
            AssertionError: If ``obj`` has an unsupported type.
        """
        if isinstance(obj, str):
            return cls(path=obj)

        if isinstance(obj, (dict, list)):
            return cls(data=obj)

        if isinstance(obj, cls):
            return obj

        # Raised explicitly so the check survives `python -O`.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")
# class QuerySet:
# def __init__(self, *paths, renumber=False):
# self.paths = paths
# self.original_queries = [load_queries(path) for path in paths]
# if renumber:
# self.queries = flatten([q.values() for q in self.original_queries])
# self.queries = {idx: text for idx, text in enumerate(self.queries)}
# else:
# self.queries = {}
# for queries in self.original_queries:
# assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
# "renumber=False requires non-overlapping query IDs"
# self.queries.update(queries)
# assert len(self.queries) == sum(map(len, self.original_queries))
# def todict(self):
# return dict(self.queries)
# def tolist(self):
# return list(self.queries.values())
# def query_sets(self):
# return self.original_queries
# def split_rankings(self, rankings):
# assert type(rankings) is list
# assert len(rankings) == len(self.queries)
# sub_rankings = []
# offset = 0
# for source in self.original_queries:
# sub_rankings.append(rankings[offset:offset+len(source)])
# offset += len(source)
# return sub_rankings