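# NOTE(review): the six entries below look like file-viewer page residue
# (apparently uploader, commit message, commit hash, view links, file size);
# they are not part of the program and are kept only as comments.
#   欧卫
#   'add_app_files'
#   58627fa
#   raw
#   history blame
#   2.96 kB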
import os
import argparse
from collections import namedtuple
from datasets import load_dataset
from utility.utils.dpr import has_answer, DPR_normalize
import tqdm
from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, RunConfig, Run
# Immutable record for one SQuAD example. get_squad_split() fills ``answers``
# with the ``"text"`` entry of the dataset's answers column.
SquadExample = namedtuple(
    "SquadExample",
    ["id", "title", "context", "question", "answers"],
)
def build_index_and_init_searcher(
    checkpoint, collection, experiment_dir, *, nbits=1, doc_maxlen=180
):
    """Index ``collection`` with ColBERT and return a searcher over that index.

    Args:
        checkpoint: Model checkpoint handed to ``Indexer``.
        collection: Collection handed to ``Indexer.index``.
        experiment_dir: Used as ``root`` for both ``ColBERTConfig`` objects.
        nbits: Number of bits used to encode each dimension (default 1).
            Also embedded in the experiment/index name.
        doc_maxlen: Passages are truncated at this many tokens (default 180).

    Returns:
        A ``Searcher`` bound to the freshly built index.
    """
    # The experiment name doubles as the index name for both indexing and search.
    experiment = f"e2etest.nbits={nbits}"

    with Run().context(RunConfig(nranks=1)):
        index_config = ColBERTConfig(
            doc_maxlen=doc_maxlen,
            nbits=nbits,
            root=experiment_dir,
            experiment=experiment,
        )
        indexer = Indexer(checkpoint, config=index_config)
        # overwrite=True is passed so that an index left over from an earlier
        # run under the same name is replaced.
        indexer.index(name=experiment, collection=collection, overwrite=True)

        # The searcher only needs to know where the experiment lives.
        search_config = ColBERTConfig(
            root=experiment_dir,
            experiment=experiment,
        )
        return Searcher(
            index=experiment,
            config=search_config,
        )
def success_at_k(searcher, examples, k):
    """Return the fraction of ``examples`` scored as a hit for the top ``k``.

    Args:
        searcher: Object forwarded to ``evaluate_retrieval_example``.
        examples: Iterable of examples to evaluate.
        k: Number of passages retrieved per question.

    Returns:
        The mean of the per-example 0/1 scores, or ``0.0`` when ``examples``
        is empty (instead of dividing by zero).
    """
    # Materialise once so emptiness can be checked before any progress-bar
    # work, and so the denominator is known even for generators.
    examples = list(examples)
    if not examples:
        return 0.0

    hits = sum(
        evaluate_retrieval_example(searcher, ex, k) for ex in tqdm.tqdm(examples)
    )
    return hits / len(examples)
def evaluate_retrieval_example(searcher, ex, k):
    """Score one example: 1 if any retrieved passage matches an answer, else 0.

    Args:
        searcher: Object providing ``search(question, k=...)`` and an indexable
            ``collection`` of passages.
        ex: Example with ``question`` and ``answers`` attributes.
        k: Number of passages to retrieve.

    Returns:
        ``1`` as soon as ``has_answer`` is truthy for a retrieved passage,
        ``0`` if none of them match.
    """
    # The normalized answers do not depend on the passage, so build them once
    # rather than once per retrieved passage.
    normalized_answers = [DPR_normalize(ans) for ans in ex.answers]

    results = searcher.search(ex.question, k=k)
    # Each zipped item is (passage id, rank, score); only the id is needed.
    for passage_id, _rank, _score in zip(*results):
        if has_answer(normalized_answers, searcher.collection[passage_id]):
            return 1
    return 0
def get_squad_split(squad, split="validation"):
    """Convert one split of ``squad`` into a list of ``SquadExample`` records.

    Columns are selected by name, so the result does not depend on the order
    in which the dataset lists its features, nor on extra columns being present.

    Args:
        squad: Mapping from split name to a dataset whose ``id``, ``title``,
            ``context``, ``question`` and ``answers`` columns can be read by
            name; each ``answers`` entry must support ``["text"]``.
        split: Name of the split to convert (default ``"validation"``).

    Returns:
        A list with one ``SquadExample`` per row, where ``answers`` holds the
        row's ``answers["text"]`` value.
    """
    dataset = squad[split]
    column_names = ("id", "title", "context", "question", "answers")
    # Read whole columns and zip them back into rows.
    rows = zip(*(dataset[name] for name in column_names))
    return [
        SquadExample(eid, title, context, question, answers["text"])
        for eid, title, context, question, answers in rows
    ]
def main(args):
    """Run the end-to-end test: index, search SQuAD questions, check success@5.

    Args:
        args: Namespace with ``checkpoint``, ``collection`` and ``expdir``
            attributes.

    Raises:
        AssertionError: If the success rate over the first 1000 validation
            examples is not greater than 0.93.
    """
    checkpoint = args.checkpoint
    collection = args.collection
    experiment_dir = args.expdir

    # Start the test
    k = 5
    searcher = build_index_and_init_searcher(checkpoint, collection, experiment_dir)

    squad = load_dataset("squad")
    squad_dev = get_squad_split(squad)
    success_rate = success_at_k(searcher, squad_dev[:1000], k)

    # Raise explicitly instead of using ``assert``: assert statements are
    # removed under ``python -O``, which would make this test always "pass".
    if not (success_rate > 0.93):
        raise AssertionError(
            f"success rate at {success_rate} is lower than expected"
        )
    print(f"test passed with success rate {success_rate}")
if __name__ == "__main__":
    # Command-line entry point: three required string options, then run main().
    parser = argparse.ArgumentParser(description="Start end-to-end test.")
    for option, description in (
        ("--checkpoint", "Model checkpoint"),
        ("--collection", "Path to collection"),
        ("--expdir", "Experiment directory"),
    ):
        parser.add_argument(option, type=str, required=True, help=description)

    args = parser.parse_args()
    main(args)