import torch

from colbert.utils.utils import load_checkpoint, flatten
from colbert.utils.amp import MixedPrecisionManager

from baleen.utils.loaders import *  # provides ujson and f7 (order-preserving dedup), among others
from baleen.condenser.model import ElectraReader
from baleen.condenser.tokenization import AnswerAwareTokenizer


class Condenser:
    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)

        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."

        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)

    def condense(self, query, backs, ranking):
        stage1_preds = self._stage1(query, backs, ranking)
        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)

        return stage1_preds, stage2_preds, stage2_preds_L3x

    def _load_model(self, path, device):
        # Read the saved arguments first to determine which ELECTRA variant to instantiate.
        model = torch.load(path, map_location='cpu')
        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
        assert model['arguments']['model'] in ElectraModels, model['arguments']

        model = ElectraReader.from_pretrained(model['arguments']['model'])
        checkpoint = load_checkpoint(path, model)

        model = model.to(device)
        model.eval()

        maxlen = checkpoint['arguments']['maxlen']

        return model, maxlen

    def _setup_inference(self, maxlen):
        amp = MixedPrecisionManager(activated=True)
        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)

        return amp, tokenizer

    def _load_collection(self, collectionX_path):
        # CollectionX maps pid -> [title, sentence_1, ...];
        # CollectionY maps (pid, sid) -> "title | sentence".
        CollectionX = {}
        CollectionY = {}

        with open(collectionX_path) as f:
            for line_idx, line in enumerate(f):
                line = ujson.loads(line)

                assert type(line['text']) is list
                assert line['pid'] == line_idx, (line_idx, line)

                passage = [line['title']] + line['text']
                CollectionX[line_idx] = passage

                passage = [line['title'] + ' | ' + sentence for sentence in line['text']]

                for idx, sentence in enumerate(passage):
                    CollectionY[(line_idx, idx)] = sentence

        return CollectionX, CollectionY

    def _stage1(self, query, BACKS, ranking, TOPK=9):
        model = self.modelL1

        with torch.inference_mode():
            # Prepend the query to the background sentences gathered in earlier hops.
            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
            backs = [query] + backs
            query = ' # '.join(backs)

            # Join each ranked passage's sentences with [MASK] so the reader scores every sentence.
            passages = []
            actual_ranking = []

            for pid in ranking:
                actual_ranking.append(pid)
                psg = self.CollectionX[pid]
                psg = ' [MASK] '.join(psg)

                passages.append(psg)

            obj = self.tokenizer.process([query], passages, None)

            with self.amp.context():
                scores = model(obj.encoding.to(model.device)).float()

            # Map every flattened score back to its (pid, sid) pair.
            pids = [[pid] * scores.size(1) for pid in actual_ranking]
            pids = flatten(pids)

            sids = [list(range(scores.size(1))) for pid in actual_ranking]
            sids = flatten(sids)

            scores = scores.view(-1)

            topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
            topk_pids = [pids[idx] for idx in topk]
            topk_sids = [sids[idx] for idx in topk]

            preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)]

            # Keep the existing backgrounds first, deduplicate, and truncate to TOPK.
            pred_plus = BACKS + preds
            pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]

        return pred_plus

    def _stage2(self, query, preds):
        model = self.modelL2

        # Re-score the stage-1 sentences jointly as a single [MASK]-separated passage.
        psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY]
        psg = ' [MASK] '.join([''] + psgX)
        passages = [psg]

        obj = self.tokenizer.process([query], passages, None)

        with self.amp.context():
            scores = model(obj.encoding.to(model.device)).float()
            scores = scores.view(-1).tolist()

        preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
        preds = sorted(preds, reverse=True)[:5]

        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]  # Take at least 2!
        preds = [x for score, x in preds if score > 0]

        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.

        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]

        assert len(preds_L3x) >= 2
        assert len(f7([pid for pid, _ in preds_L3x])) <= 4

        return preds, preds_L3x
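
# --- Minimal usage sketch ---
# The checkpoint paths, collection path, query, and pids below are hypothetical
# placeholders for illustration; substitute your own trained condenser
# checkpoints, a pid-indexed JSONL collection, and a retriever's ranking.
if __name__ == '__main__':
    condenser = Condenser(
        collectionX_path='data/collection.jsonl',    # one {'pid', 'title', 'text': [...]} object per line
        checkpointL1='checkpoints/condenserL1.dnn',  # stage-1 ELECTRA reader checkpoint (hypothetical path)
        checkpointL2='checkpoints/condenserL2.dnn',  # stage-2 ELECTRA reader checkpoint (hypothetical path)
    )

    query = "Which magazine was started first, Arthur's Magazine or First for Women?"
    backs = []              # (pid, sid) background facts from earlier hops; empty on the first hop
    ranking = [17, 42, 7]   # pids of retrieved passages, best first (hypothetical)

    stage1_preds, stage2_preds, stage2_preds_L3x = condenser.condense(query, backs, ranking)
    print(stage2_preds)     # [(pid, sid), ...] of positively scored facts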