Spaces:
Runtime error
Runtime error
File size: 1,618 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser
class Baleen:
    """Multi-hop retrieval loop alternating a searcher and a condenser.

    On each hop the searcher is queried with the facts gathered so far as
    context. The top passages not already used as facts are handed to the
    condenser, and the facts the condenser returns become the context for
    the next hop.
    """

    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        """Load the sentence-level collection and store the retrieval components.

        Args:
            collectionX_path: Path passed to ``load_collectionX``. The loaded
                object is used as a mapping with ``.get((pid, sid), default)``
                lookups (see ``search``).
            searcher: Object exposing ``search(query, context=..., k=...)``.
                It returns parallel sequences that ``zip`` into
                ``(pid, rank, score)`` rows.
            condenser: Object exposing ``condense(query, backs=..., ranking=...)``.
                It returns ``(stage1_preds, facts, stage2_L3x)``.
        """
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        """Run ``num_hops`` rounds of retrieval and condensation for ``query``.

        Args:
            query: The question. Passed unchanged to the searcher and condenser.
            num_hops: Number of retrieval hops. Must be a positive integer
                that divides ``depth``.
            depth: Number of candidates requested from the searcher on every
                hop. Also the expected size of the returned passage bag.
            verbose: Accepted for interface compatibility; currently unused.

        Returns:
            ``(stage2_L3x, pids_bag, stage1_preds)``, where:
              - ``stage1_preds`` and ``stage2_L3x`` are the condenser's outputs
                from the final hop.
              - ``pids_bag`` is the set of passage ids collected across hops.
                Each hop may add up to ``depth // num_hops`` new ids.

        Raises:
            ValueError: If ``num_hops`` is not positive or does not divide
                ``depth``.
            AssertionError: If fewer than ``depth`` distinct pids were
                collected, e.g. because the searcher returned fewer than
                ``depth`` unique pids.
        """
        # Explicit checks instead of `assert`: asserts are stripped under -O,
        # and num_hops == 0 would otherwise raise ZeroDivisionError here.
        if num_hops < 1:
            raise ValueError(f"num_hops={num_hops} must be a positive integer.")
        if depth % num_hops != 0:
            raise ValueError(f"depth={depth} must be divisible by num_hops={num_hops}.")

        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []  # (pid, sid) pairs selected by the condenser so far
        stage1_preds = None
        stage2_L3x = None
        context = None

        pids_bag = set()

        for hop_idx in range(num_hops):
            # zip(*...) turns the searcher's parallel sequences into (pid, rank, score) rows.
            ranking = zip(*searcher.search(query, context=context, k=depth))

            # Passages already chosen as facts are not offered to the condenser again.
            facts_pids = {pid for pid, _ in facts}
            # The bag may grow by at most k new pids on each hop.
            bag_cap = k * (hop_idx + 1)

            ranking_ = []
            for pid, _rank, _score in ranking:
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                if len(pids_bag) < bag_cap:
                    pids_bag.add(pid)

                # Once both quotas are full, no later row can change either collection.
                if len(ranking_) >= k and len(pids_bag) >= bag_cap:
                    break

            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
            # Next hop's context: the text of each selected fact; missing keys contribute ''.
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        # Internal invariant: holds when each search returns `depth` unique pids.
        assert len(pids_bag) == depth

        return stage2_L3x, pids_bag, stage1_preds
|