欧卫
'add_app_files'
58627fa
raw
history blame
3.47 kB
import os
import sys
import time
import __main__
import traceback
# import mlflow
import colbert.utils.distributed as distributed
from contextlib import contextmanager
from colbert.utils.logging import Logger
from colbert.utils.utils import timestamp, create_directory, print_message
class _RunManager():
    """Track the output directory, logger, and exit status of one run.

    An instance starts with a default name and no output path. ``init()``
    fixes the run's location on disk and attaches a ``Logger``; ``context()``
    wraps the body of the run so that failures and timing information are
    recorded through that logger.
    """

    def __init__(self):
        self.experiments_root = None
        self.experiment = None
        self.path = None
        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        # Remember the auto-generated name even if init() later replaces it.
        self.original_name = self.name
        self.exit_status = 'FINISHED'
        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        """Set the run's output path and attach its logger.

        The path is ``<abs root>/<experiment>/<script>/<name>``. When
        ``rank < 1`` the directory is created, or, if it already exists, the
        user is asked on stdin whether it may be overwritten.

        Args:
            rank: Process rank; only ``rank < 1`` touches the filesystem and
                prompts (presumably the main process -- confirm with caller).
            root: Root directory for all experiments (made absolute).
            experiment: Experiment name; must not contain ``'/'``.
            name: Run name; must not contain ``'/'``.

        Raises:
            AssertionError: If ``experiment`` or ``name`` contains ``'/'``, or
                if the path exists and the answer is anything but ``yes``.
        """
        # Explicit raises instead of `assert` so these checks still run under
        # `python -O`. AssertionError is kept so callers see the same type.
        if '/' in experiment:
            raise AssertionError(experiment)
        if '/' in name:
            raise AssertionError(name)

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.
                response = input()
                # Refuse to continue unless the user explicitly agreed.
                if response.strip() != 'yes' and os.path.exists(self.path):
                    raise AssertionError(self.path)
            else:
                create_directory(self.path)

        # Called by every rank, after the rank < 1 filesystem work above.
        distributed.barrier(rank)

        self._logger = Logger(rank, self)
        # Re-export the logger's methods so callers can use Run.info(...) etc.
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        """Return the default run name, taken from ``timestamp()``."""
        return timestamp()

    def _get_script_name(self):
        """Return the basename of the ``__main__`` script, or ``'none'``.

        ``'none'`` is returned when ``__main__`` has no ``__file__`` (or it is
        ``None``), e.g. in an interactive session.
        """
        script_path = getattr(__main__, '__file__', None)
        if script_path is None:
            return 'none'
        return os.path.basename(script_path)

    def _record_failure(self, ex):
        """Log exception ``ex`` and flush artifacts, if a logger is attached."""
        if self._logger is None:
            # init() was never called; there is nowhere to log to.
            return
        self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
        self._logger._log_all_artifacts()

    def _write_run_summary(self):
        """Write elapsed time and run names as artifacts, if a logger is attached."""
        if self._logger is None:
            # init() was never called; there is nowhere to log to.
            return

        logs_path = self._logger.logs_path
        total_seconds = str(time.time() - self.start_time) + '\n'

        self.log_new_artifact(os.path.join(logs_path, 'elapsed.txt'), total_seconds)
        self.log_new_artifact(os.path.join(logs_path, 'name.original.txt'), str(self.original_name))
        self.log_new_artifact(os.path.join(logs_path, 'name.txt'), str(self.name))
        self._logger._log_all_artifacts()

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        """Wrap the body of a run, recording failures and a final summary.

        * ``KeyboardInterrupt``: logged; if ``consider_failed_if_interrupted``
          is true the status becomes ``'KILLED'`` and the process exits with
          code 130 (128 + SIGINT), otherwise the interrupt is suppressed.
        * Any other ``Exception``: logged, status becomes ``'FAILED'``, and the
          exception is re-raised unchanged.
        * Always: elapsed time and run names are written as artifacts.

        If ``init()`` has not been called there is no logger, so the logging
        steps are skipped instead of masking the caller's exception with an
        ``AttributeError``.

        Args:
            consider_failed_if_interrupted: Whether Ctrl-C terminates the
                process with a failure status.
        """
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._record_failure(ex)

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'  # mlflow.entities.RunStatus.KILLED
                sys.exit(128 + 2)

        except Exception as ex:
            self._record_failure(ex)
            self.exit_status = 'FAILED'  # mlflow.entities.RunStatus.FAILED
            # Bare raise re-raises the active exception with its traceback intact.
            raise

        finally:
            self._write_run_summary()
            # mlflow.end_run(status=self.exit_status)
# Module-level run manager instance, created at import time; its start_time
# is therefore the moment this module is first imported.
Run = _RunManager()