Spaces:
Sleeping
Sleeping
File size: 1,689 Bytes
46ff99b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
from dinov2.logging import setup_logging
from dinov2.run.submit import get_args_parser, submit_jobs
from dinov2.train import get_args_parser as get_train_args_parser
logger = logging.getLogger("dinov2")
class Trainer(object):
def __init__(self, args):
self.args = args
def __call__(self):
from dinov2.train import main as train_main
self._setup_args()
train_main(self.args)
def checkpoint(self):
import submitit
logger.info(f"Requeuing {self.args}")
empty = type(self)(self.args)
return submitit.helpers.DelayedSubmission(empty)
def _setup_args(self):
import submitit
job_env = submitit.JobEnvironment()
self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
logger.info(
f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}"
)
logger.info(f"Args: {self.args}")
def main():
description = "Submitit launcher for DINOv2 training"
train_args_parser = get_train_args_parser(add_help=False)
parents = [train_args_parser]
args_parser = get_args_parser(description=description, parents=parents)
args = args_parser.parse_args()
setup_logging()
assert os.path.exists(args.config_file), "Configuration file does not exist!"
submit_jobs(Trainer, args, name="dinov2:train")
return 0
if __name__ == "__main__":
sys.exit(main())
|