# NOTE(review): "Spaces: Runtime error" -- page status text captured along with this file; not Python source.
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser
class Baleen:
    """Multi-hop retrieval loop that alternates a searcher with a condenser.

    Each hop queries ``searcher`` with the accumulated textual context, keeps
    the top passages that are not already among the selected facts, and hands
    them to ``condenser`` to obtain the updated list of facts.
    """

    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        """Load the fact-text lookup and store the collaborators.

        Args:
            collectionX_path: Path passed to ``load_collectionX``. The loaded
                object is queried with ``.get((pid, sid), '')`` in ``search``.
            searcher: Object exposing ``search(query, context=..., k=...)``.
            condenser: Object exposing
                ``condense(query, backs=..., ranking=...)``.
        """
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        """Run ``num_hops`` rounds of retrieve-then-condense for ``query``.

        Args:
            query: Query forwarded unchanged to the searcher and the condenser.
            num_hops: Number of retrieval rounds; must divide ``depth`` evenly.
            depth: Number of results requested from the searcher on every hop,
                and the required final size of the returned pid set.
            verbose: Currently unused; kept for interface compatibility.

        Returns:
            A tuple ``(stage2_L3x, pids_bag, stage1_preds)`` where
            ``stage2_L3x`` and ``stage1_preds`` are the third and first values
            returned by the last ``condenser.condense`` call, and ``pids_bag``
            is the set of passage ids collected across all hops.

        Raises:
            AssertionError: If ``depth`` is not divisible by ``num_hops``, or
                if fewer than ``depth`` distinct pids were collected.
            ZeroDivisionError: If ``num_hops`` is 0.
        """
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops  # per-hop budget of passages

        facts = []  # (pid, sid) pairs; sid is presumably a sentence id -- confirm
        stage1_preds = None
        stage2_L3x = None
        context = None  # the first hop searches without any context
        pids_bag = set()

        for hop_idx in range(num_hops):
            results = self.searcher.search(query, context=context, k=depth)

            # Passages that already contributed a fact are not re-submitted
            # to the condenser.
            facts_pids = {pid for pid, _ in facts}
            hop_ranking = []
            # The bag may grow by at most k distinct pids per hop.
            bag_limit = k * (hop_idx + 1)

            # The searcher returns parallel sequences; zip(*...) turns them
            # into (pid, rank, score) triples.
            for pid, _rank, _score in zip(*results):
                if len(hop_ranking) < k and pid not in facts_pids:
                    hop_ranking.append(pid)

                if len(pids_bag) < bag_limit:
                    pids_bag.add(pid)

            stage1_preds, facts, stage2_L3x = self.condenser.condense(
                query, backs=facts, ranking=hop_ranking
            )
            # Facts missing from the lookup contribute an empty string.
            context = ' [SEP] '.join(
                [self.collectionX.get((pid, sid), '') for pid, sid in facts]
            )

        assert len(pids_bag) == depth, (
            f"expected {depth} distinct pids, collected {len(pids_bag)}"
        )

        return stage2_L3x, pids_bag, stage1_preds