Spaces:
Runtime error
Runtime error
import os | |
import time | |
import torch.multiprocessing as mp | |
from colbert.infra.run import Run | |
from colbert.infra.config import ColBERTConfig, RunConfig | |
from colbert.infra.launcher import Launcher | |
from colbert.utils.utils import create_directory, print_message | |
from colbert.indexing.collection_indexer import encode | |
class Indexer:
    """Drive the construction of an index from a checkpoint and a collection.

    The effective configuration is assembled in ``__init__`` from the
    checkpoint's stored config, the optional ``config`` argument and the
    active ``Run().config``; the index itself is produced by ``index()``.
    """

    def __init__(self, checkpoint, config=None):
        """
        Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.

        Args:
            checkpoint: Checkpoint identifier handed to
                ``ColBERTConfig.load_from_checkpoint`` and recorded in the
                resulting configuration.
            config: Optional extra configuration merged together with the
                checkpoint's config and ``Run().config`` via
                ``ColBERTConfig.from_existing``.
        """
        # Unknown until index() has been called.
        self.index_path = None
        self.checkpoint = checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)

        self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
        self.configure(checkpoint=checkpoint)

    def configure(self, **kw_args):
        """Forward keyword settings to the underlying configuration object."""
        self.config.configure(**kw_args)

    def get_index(self):
        """Return the index path set by the last ``index()`` call, or None."""
        return self.index_path

    def erase(self):
        """Remove previously written index artifacts from the index directory.

        A directory entry is removed when its name ends in ``.pt``, or when
        it ends in ``.json`` and the name contains ``metadata``, ``doclen``
        or ``plan``. When at least one entry matches, a message is printed
        and the method waits 20 seconds before deleting anything.

        Returns:
            The list of full paths that were deleted (empty if none matched).

        Raises:
            AssertionError: if ``index_path`` has not been set yet.
        """
        assert self.index_path is not None
        directory = self.index_path
        deleted = []

        for entry in sorted(os.listdir(directory)):
            # Match against the bare entry name, not the joined path: otherwise
            # a directory whose own path contains e.g. "metadata" would make
            # every .json file inside it look like an index artifact.
            is_index_json = entry.endswith(".json") and (
                'metadata' in entry or 'doclen' in entry or 'plan' in entry
            )

            if is_index_json or entry.endswith(".pt"):
                deleted.append(os.path.join(directory, entry))

        if deleted:
            print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
            # Grace period before the irreversible removal below.
            time.sleep(20)

            for filename in deleted:
                os.remove(filename)

        return deleted

    def index(self, name, collection, overwrite=False):
        """Build (or reuse) the index called ``name`` for ``collection``.

        Args:
            name: Index name, stored in the configuration as ``index_name``.
            collection: Collection to index; stored in the configuration and
                passed on to the launched encoder.
            overwrite: One of
                ``False``    - the index path must not exist yet;
                ``True``     - erase matching files at the path, then build;
                ``'reuse'``  - skip building when the path already exists;
                ``'resume'`` - configure ``resume=True`` and build.

        Returns:
            The index path taken from ``self.config.index_path_``.

        Raises:
            AssertionError: if ``overwrite`` is not one of the values above,
                or if it is ``False`` and the index path already exists.
        """
        assert overwrite in [True, False, 'reuse', 'resume']

        self.configure(collection=collection, index_name=name, resume=overwrite == 'resume')
        # These two settings are always forced here, regardless of earlier configuration.
        self.configure(bsize=64, partitions=None)

        self.index_path = self.config.index_path_
        # Must be sampled before create_directory() below makes the path exist.
        index_does_not_exist = not os.path.exists(self.config.index_path_)

        assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
        create_directory(self.config.index_path_)

        if overwrite is True:
            self.erase()

        if index_does_not_exist or overwrite != 'reuse':
            self.__launch(collection)

        return self.index_path

    def __launch(self, collection):
        """Start the ``encode`` job through ``Launcher``.

        One manager-backed list and one manager-backed queue (``maxsize=1``)
        are created per rank (``self.config.nranks``) and passed to the
        launcher together with the configuration and the collection.
        """
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues)