Spaces:
Runtime error
Runtime error
import torch | |
from colbert.utils.utils import load_checkpoint | |
from colbert.utils.amp import MixedPrecisionManager | |
from colbert.utils.utils import flatten | |
from baleen.utils.loaders import * | |
from baleen.condenser.model import ElectraReader | |
from baleen.condenser.tokenization import AnswerAwareTokenizer | |
class Condenser: | |
def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'): | |
self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1) | |
self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2) | |
assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers." | |
self.amp, self.tokenizer = self._setup_inference(self.maxlenL2) | |
self.CollectionX, self.CollectionY = self._load_collection(collectionX_path) | |
def condense(self, query, backs, ranking): | |
stage1_preds = self._stage1(query, backs, ranking) | |
stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds) | |
return stage1_preds, stage2_preds, stage2_preds_L3x | |
def _load_model(self, path, device): | |
model = torch.load(path, map_location='cpu') | |
ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator'] | |
assert model['arguments']['model'] in ElectraModels, model['arguments'] | |
model = ElectraReader.from_pretrained(model['arguments']['model']) | |
checkpoint = load_checkpoint(path, model) | |
model = model.to(device) | |
model.eval() | |
maxlen = checkpoint['arguments']['maxlen'] | |
return model, maxlen | |
def _setup_inference(self, maxlen): | |
amp = MixedPrecisionManager(activated=True) | |
tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen) | |
return amp, tokenizer | |
def _load_collection(self, collectionX_path): | |
CollectionX = {} | |
CollectionY = {} | |
with open(collectionX_path) as f: | |
for line_idx, line in enumerate(f): | |
line = ujson.loads(line) | |
assert type(line['text']) is list | |
assert line['pid'] == line_idx, (line_idx, line) | |
passage = [line['title']] + line['text'] | |
CollectionX[line_idx] = passage | |
passage = [line['title'] + ' | ' + sentence for sentence in line['text']] | |
for idx, sentence in enumerate(passage): | |
CollectionY[(line_idx, idx)] = sentence | |
return CollectionX, CollectionY | |
def _stage1(self, query, BACKS, ranking, TOPK=9): | |
model = self.modelL1 | |
with torch.inference_mode(): | |
backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY] | |
backs = [query] + backs | |
query = ' # '.join(backs) | |
# print(query) | |
# print(backs) | |
passages = [] | |
actual_ranking = [] | |
for pid in ranking: | |
actual_ranking.append(pid) | |
psg = self.CollectionX[pid] | |
psg = ' [MASK] '.join(psg) | |
passages.append(psg) | |
obj = self.tokenizer.process([query], passages, None) | |
with self.amp.context(): | |
scores = model(obj.encoding.to(model.device)).float() | |
pids = [[pid] * scores.size(1) for pid in actual_ranking] | |
pids = flatten(pids) | |
sids = [list(range(scores.size(1))) for pid in actual_ranking] | |
sids = flatten(sids) | |
scores = scores.view(-1) | |
topk = scores.topk(min(TOPK, len(scores))).indices.tolist() | |
topk_pids = [pids[idx] for idx in topk] | |
topk_sids = [sids[idx] for idx in topk] | |
preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)] | |
pred_plus = BACKS + preds | |
pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK] | |
return pred_plus | |
def _stage2(self, query, preds): | |
model = self.modelL2 | |
psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY] | |
psg = ' [MASK] '.join([''] + psgX) | |
passages = [psg] | |
# print(passages) | |
obj = self.tokenizer.process([query], passages, None) | |
with self.amp.context(): | |
scores = model(obj.encoding.to(model.device)).float() | |
scores = scores.view(-1).tolist() | |
preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)] | |
preds = sorted(preds, reverse=True)[:5] | |
preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)] # Take at least 2! | |
preds = [x for score, x in preds if score > 0] | |
earliest_pids = f7([pid for pid, _ in preds_L3x])[:4] # Take at most 4 docs. | |
preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids] | |
assert len(preds_L3x) >= 2 | |
assert len(f7([pid for pid, _ in preds_L3x])) <= 4 | |
return preds, preds_L3x | |