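# NOTE(review): the three items below are not Python; they look like residue from a
# version-control page (presumably author / commit message / short commit hash).
# 欧卫 | 'add_app_files' | 58627fa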
import os
import time
import torch.multiprocessing as mp
from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.infra.launcher import Launcher
from colbert.utils.utils import create_directory, print_message
from colbert.indexing.collection_indexer import encode
class Indexer:
    """Drive the construction of an index for a collection from a checkpoint.

    The instance keeps a merged configuration (``self.config``) and remembers the
    path of the last index it was asked to build (``self.index_path``, ``None``
    until `index` is called).
    """

    def __init__(self, checkpoint, config=None):
        """
        Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.

        Args:
            checkpoint: Checkpoint identifier handed to
                ``ColBERTConfig.load_from_checkpoint`` and recorded in the config.
            config: Optional configuration combined with the checkpoint's own
                configuration and ``Run().config`` via ``ColBERTConfig.from_existing``.
        """
        self.index_path = None
        self.checkpoint = checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)

        # NOTE(review): precedence between the three sources is decided inside
        # ColBERTConfig.from_existing, which is not visible here.
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
        self.configure(checkpoint=checkpoint)

    def configure(self, **kw_args):
        """Forward keyword overrides to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def get_index(self):
        """Return the path of the most recently requested index, or ``None``."""
        return self.index_path

    def erase(self):
        """Remove index artifacts located directly inside ``self.index_path``.

        A file is selected when its name ends in ``.pt``, or when it ends in
        ``.json`` and its name contains ``metadata``, ``doclen`` or ``plan``.
        If anything is selected, a warning is printed, the call sleeps for
        20 seconds, and only then are the files removed.

        Returns:
            The sorted list of full paths that were selected (and removed).

        Raises:
            AssertionError: if no index path has been set yet.
        """
        assert self.index_path is not None
        directory = self.index_path

        deleted = []
        for name in sorted(os.listdir(directory)):
            # Match on the bare file name, not the joined path: otherwise a
            # directory component such as ".../plan_b/" would mark every
            # .json file in the index directory for deletion.
            is_index_json = name.endswith(".json") and (
                'metadata' in name or 'doclen' in name or 'plan' in name
            )

            if is_index_json or name.endswith(".pt"):
                deleted.append(os.path.join(directory, name))

        if deleted:
            print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
            # Grace period so the user can abort before anything is removed.
            time.sleep(20)

            for path in deleted:
                os.remove(path)

        return deleted

    def index(self, name, collection, overwrite=False):
        """Build (or reuse) the index called `name` for `collection`.

        Args:
            name: Index name, stored in the config as ``index_name``.
            collection: Collection to encode; stored in the config and passed
                to the launcher.
            overwrite: One of
                ``False``    - the index path must not exist yet; build it.
                ``True``     - erase matching files already there, then build.
                ``'reuse'``  - if the index path already exists, skip building.
                ``'resume'`` - build with ``resume=True`` set in the config.

        Returns:
            The index path taken from ``self.config.index_path_``.

        Raises:
            AssertionError: if `overwrite` is not an accepted value, or if it
                is ``False`` while the index path already exists.
        """
        assert overwrite in [True, False, 'reuse', 'resume'], overwrite

        self.configure(collection=collection, index_name=name, resume=overwrite == 'resume')
        # These two settings are applied on every call, after whatever was
        # configured before.
        self.configure(bsize=64, partitions=None)

        self.index_path = self.config.index_path_
        # Captured before create_directory() below so that 'reuse' can tell a
        # pre-existing index from a freshly created directory.
        index_does_not_exist = (not os.path.exists(self.config.index_path_))

        assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
        create_directory(self.config.index_path_)

        if overwrite is True:
            self.erase()

        if index_does_not_exist or overwrite != 'reuse':
            self.__launch(collection)

        return self.index_path

    def __launch(self, collection):
        """Start the encoding job through `Launcher`.

        Creates one manager-backed list and one manager-backed queue
        (``maxsize=1``) per rank in ``self.config.nranks`` and hands them,
        together with the config and the collection, to ``Launcher(encode).launch``.
        """
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues)