# NOTE: page-scrape residue, not part of the module — "Spaces: Runtime error".
import os
import sys
import time
import __main__
import traceback
from contextlib import contextmanager

# import mlflow

import colbert.utils.distributed as distributed
from colbert.utils.logging import Logger
from colbert.utils.utils import timestamp, create_directory, print_message
class _RunManager():
    """Hold the identity, output directory and logger of one experiment run.

    An instance starts with a default name (from ``timestamp()``) and no
    path or logger.  ``init()`` fixes the experiment/run names, prepares the
    run directory and attaches a ``Logger``; ``context()`` then wraps the
    actual work so that failures and timing information get logged.
    """

    def __init__(self):
        # These three stay None until init() is called.
        self.experiments_root = None
        self.experiment = None
        self.path = None

        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        # Remember the auto-generated name; init() may replace self.name.
        self.original_name = self.name
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        """Configure the run directory and attach the logger.

        The run path is ``<root>/<experiment>/<script>/<name>`` with ``root``
        made absolute.  On ranks below 1, an existing path triggers an
        interactive overwrite prompt, while a missing path is created.

        Args:
            rank: Process rank; only ``rank < 1`` checks/creates the directory.
            root: Root directory for all experiments.
            experiment: Experiment name; must not contain ``/``.
            name: Run name; must not contain ``/``.

        Raises:
            AssertionError: If ``experiment`` or ``name`` contains ``/``, or
                if the path already exists and the user does not answer
                ``yes`` to the overwrite prompt.
        """
        # Both names become single path components, so separators are rejected.
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                response = input()
                if response.strip() != 'yes':
                    # The path exists here, so this assertion fails on purpose
                    # and aborts the run with the path as the message.
                    assert not os.path.exists(self.path), self.path
            else:
                create_directory(self.path)

        # NOTE(review): presumably makes the other ranks wait until the
        # directory handling above is done — confirm against distributed.barrier.
        distributed.barrier(rank)

        self._logger = Logger(rank, self)

        # Re-export the logger's methods so callers can use Run.info(...) etc.
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        """Return the default run name, taken from ``timestamp()``."""
        return timestamp()

    def _get_script_name(self):
        """Return the basename of the entry script, or ``'none'`` if unknown."""
        # __main__ has no __file__ in e.g. an interactive session.
        if hasattr(__main__, '__file__'):
            return os.path.basename(__main__.__file__)
        return 'none'

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        """Context manager that wraps the body of a run.

        Must be used after ``init()``, because it relies on the logger and on
        ``log_new_artifact`` that ``init()`` sets up.

        * ``KeyboardInterrupt``: the exception and artifacts are logged; if
          ``consider_failed_if_interrupted`` is true the status becomes
          ``'KILLED'`` and the process exits with code 130, otherwise the
          interrupt is absorbed and the ``with`` block ends normally.
        * Any other ``Exception``: logged, status set to ``'FAILED'``, then
          re-raised unchanged.
        * Always: elapsed seconds and both run names are written as artifacts
          under the logger's ``logs_path``.

        Args:
            consider_failed_if_interrupted: Whether Ctrl-C marks the run as
                killed and terminates the process.
        """
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'  # mlflow.entities.RunStatus.KILLED
                sys.exit(128 + 2)  # 128 + signal number 2

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'  # mlflow.entities.RunStatus.FAILED

            # Bare raise re-raises the active exception without adding a frame.
            raise

        finally:
            total_seconds = str(time.time() - self.start_time) + '\n'
            original_name = str(self.original_name)
            name = str(self.name)
            logs_path = self._logger.logs_path

            self.log_new_artifact(os.path.join(logs_path, 'elapsed.txt'), total_seconds)
            self.log_new_artifact(os.path.join(logs_path, 'name.original.txt'), original_name)
            self.log_new_artifact(os.path.join(logs_path, 'name.txt'), name)

            self._logger._log_all_artifacts()

            # mlflow.end_run(status=self.exit_status)
# Module-level shared instance; _RunManager.__init__ runs when this module is imported.
Run = _RunManager()