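"""
Loaders for ColBERT evaluation data: queries, qrels, top-k rankings
(with or without relevance labels), the passage collection, and a
trained ColBERT checkpoint.
"""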
import os
import ujson
import torch
import random
from collections import defaultdict, OrderedDict
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    """Load a tab-separated queries file (qid, query, ...) into an OrderedDict keyed by int qid."""
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    """Load TREC-style qrels (tab-separated: qid, 0, pid, 1) into an OrderedDict mapping
    qid -> deduplicated list of relevant pids; returns None if no path is given."""
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    # assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)
    for qid in qrels:
        qrels[qid] = list(set(qrels[qid]))

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
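

# Line shapes accepted by the two top-k loaders below (tab-separated), as
# inferred from their parsing code rather than from any official spec:
#   load_topK:        qid, pid, query, passage
#   load_topK_pids:   qid, pid, query
#                     qid, pid, query, label
#                     qid, pid, query, passage, label    (label in {0, 1})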
def load_topK(topK_path):
    """Load a ranked top-k file; returns (queries, topK_docs, topK_pids), each an OrderedDict keyed by qid."""
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query

            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)

            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    """Load the top-k candidate PIDs per query, plus any 0/1 relevance labels in the file;
    returns (topK_pids, positives), where positives falls back to `qrels` for unlabeled files."""
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)
        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    """Load the passage collection (tab-separated: pid, passage[, title]) as a list indexed by pid.
    When a title is present, it is prepended to the passage as 'title | passage'."""
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip('\n\r ').split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection
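

# Note: `load_colbert` takes an argparse-style namespace. Judging from the code
# below (and from `load_model`, which is defined elsewhere), it reads at least
# `args.rank` and, when present, `args.query_maxlen`, `args.doc_maxlen`,
# `args.dim`, `args.similarity`, and `args.amp`, comparing them against the
# values stored in the checkpoint.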
def load_colbert(args, do_print=True):
    """Load a trained ColBERT model and its checkpoint via load_model, warning when
    checkpoint arguments differ from the values currently in `args`."""
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
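

# --- Usage sketch (illustrative; not part of the original module) ---
# Shows how the loaders above compose for an evaluation run. The file paths are
# hypothetical placeholders for tab-separated files in the formats noted above.
if __name__ == '__main__':
    queries = load_queries('data/queries.dev.small.tsv')    # hypothetical path
    qrels = load_qrels('data/qrels.dev.small.tsv')          # hypothetical path

    # With an unlabeled top-k file, the returned `positives` falls back to `qrels`.
    topK_pids, positives = load_topK_pids('data/top1000.dev.tsv', qrels=qrels)

    print(len(queries), 'queries loaded;', len(topK_pids), 'of them have candidate PIDs.')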