Spaces:
Runtime error
Runtime error
File size: 2,624 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import ujson
from functools import partial
from colbert.infra.config.config import ColBERTConfig
from colbert.utils.utils import print_message, zipstar
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples
from colbert.evaluation.loaders import load_collection
from colbert.data.collection import Collection
from colbert.data.queries import Queries
from colbert.data.examples import Examples
# from colbert.utils.runs import Run
class LazyBatcher:
    """Yield training batches of ``bsize`` examples, resolving ids to text lazily.

    Each example is unpacked as ``(qid, *pids)``. The query text and the passage
    texts are looked up in ``queries`` / ``collection`` only when the batch
    containing that example is requested. Examples are consumed in order. A
    trailing partial batch (fewer than ``bsize`` examples left) is dropped, because
    ``collate`` requires exactly ``bsize`` queries.

    ``__iter__`` returns ``self``, so an instance can be iterated only once.
    """

    def __init__(self, config: ColBERTConfig, triples, queries, collection, rank=0, nranks=1):
        """Build the tokenizers and materialize the example list.

        Args:
            config: provides ``bsize``, ``accumsteps`` and ``nway``; also passed
                to ``QueryTokenizer`` and ``DocTokenizer``.
            triples: anything ``Examples.cast`` accepts. It is turned into a list via
                ``.tolist(rank, nranks)`` (presumably this rank's shard; confirm
                in ``Examples``).
            queries: anything ``Queries.cast`` accepts; indexed by query id.
            collection: anything ``Collection.cast`` accepts; indexed by passage id.
            rank: forwarded to ``Examples.tolist``.
            nranks: forwarded to ``Examples.tolist``.
        """
        self.bsize, self.accumsteps = config.bsize, config.accumsteps
        self.nway = config.nway

        self.query_tokenizer = QueryTokenizer(config)
        self.doc_tokenizer = DocTokenizer(config)
        # Pre-bind the tokenizers so collate() only supplies the batch data.
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0  # index of the next example to consume

        self.triples = Examples.cast(triples, nway=self.nway).tolist(rank, nranks)
        self.queries = Queries.cast(queries)
        self.collection = Collection.cast(collection)

    def __iter__(self):
        return self

    def __len__(self):
        # NOTE: this is the number of examples, not the number of batches.
        return len(self.triples)

    def __next__(self):
        """Return the next collated batch.

        Raises:
            StopIteration: when fewer than ``bsize`` examples remain.
        """
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        # Drop the trailing partial batch: collate() needs exactly bsize queries.
        if offset + self.bsize > len(self.triples):
            raise StopIteration

        all_queries, all_passages, all_scores = [], [], []

        for position in range(offset, endpos):
            query, *pids = self.triples[position]
            pids = pids[:self.nway]

            query = self.queries[query]

            # If the entries can be split into two columns (presumably
            # (pid, score) pairs), separate ids from scores. Plain ids make
            # zipstar raise TypeError, and a column-count mismatch makes the
            # unpack raise ValueError. In both cases pids is left unchanged and
            # the example contributes no scores. Other exceptions (including
            # KeyboardInterrupt) now propagate instead of being swallowed.
            try:
                pids, scores = zipstar(pids)
            except (TypeError, ValueError):
                scores = []

            passages = [self.collection[pid] for pid in pids]

            all_queries.append(query)
            all_passages.extend(passages)
            all_scores.extend(scores)

        # Either no example carried scores, or every passage has one.
        assert len(all_scores) in [0, len(all_passages)], len(all_scores)

        return self.collate(all_queries, all_passages, all_scores)

    def collate(self, queries, passages, scores):
        """Check the batch shape and hand it to ``tensorize_triples``.

        ``self.bsize // self.accumsteps`` is passed as the fourth argument,
        presumably the per-accumulation-step sub-batch size (confirm in
        ``tensorize_triples``).
        """
        assert len(queries) == self.bsize, (len(queries), self.bsize)
        assert len(passages) == self.nway * self.bsize, (len(passages), self.nway, self.bsize)

        return self.tensorize_triples(queries, passages, scores, self.bsize // self.accumsteps, self.nway)
# def skip_to_batch(self, batch_idx, intended_batch_size):
# Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
# self.position = intended_batch_size * batch_idx
|