imw34531's picture
Upload folder using huggingface_hub
87e21d1 verified
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import datetime
import os
import os.path as osp
import subprocess
from termcolor import colored
def supports_gpus_per_node():
VILA_DATASETS = os.environ.get("VILA_DATASETS", "")
if "eos" in VILA_DATASETS.lower():
return False
return True
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--job-name", "-J", type=str, required=True)
parser.add_argument("--nodes", "-N", type=int, default=1)
parser.add_argument("--gpus-per-node", type=int, default=8)
parser.add_argument("--mode", "-m", type=str, default="train")
parser.add_argument("--time", "-t", type=str, default="4:00:00")
parser.add_argument("--timedelta", type=int, default=5)
parser.add_argument("--output-dir", type=str)
parser.add_argument("--max-retry", type=int, default=-1)
# -1: indicates none, for train jobs, it will be set 3 and otherwise 1
parser.add_argument("--pty", action="store_true")
parser.add_argument("cmd", nargs=argparse.REMAINDER)
args = parser.parse_args()
if args.max_retry < 0:
if args.mode == "train":
args.max_retry = 3
else:
args.max_retry = 0
# Generate run name and output directory
if "%t" in args.job_name:
args.job_name = args.job_name.replace("%t", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
if args.output_dir is None:
args.output_dir = os.path.join("runs", args.mode, args.job_name)
output_dir = os.path.expanduser(args.output_dir)
# Calculate the timeout
time = datetime.datetime.strptime(args.time, "%H:%M:%S")
if time < datetime.datetime.strptime("0:01:00", "%H:%M:%S"):
raise ValueError("Time must be at least 1 minutes")
timeout = time - datetime.timedelta(minutes=args.timedelta)
timeout = timeout.hour * 60 + timeout.minute
timeout = f"{timeout}m"
# Get SLURM account and partition
if "VILA_SLURM_ACCOUNT" not in os.environ or "VILA_SLURM_PARTITION" not in os.environ:
raise ValueError("`VILA_SLURM_ACCOUNT` and `VILA_SLURM_PARTITION` must be set in the environment.")
account = os.environ["VILA_SLURM_ACCOUNT"]
partition = os.environ["VILA_SLURM_PARTITION"]
# Set environment variables
env = os.environ.copy()
env["RUN_NAME"] = args.job_name
env["OUTPUT_DIR"] = output_dir
# Compose the SLURM command
cmd = ["srun"]
cmd += ["--account", account]
cmd += ["--partition", partition]
cmd += ["--job-name", f"{account}:{args.mode}/{args.job_name}"]
if not args.pty:
# Redirect output to files if not pty / interactive
cmd += ["--output", f"{output_dir}/slurm/%J.out"]
cmd += ["--error", f"{output_dir}/slurm/%J.err"]
cmd += ["--nodes", str(args.nodes)]
if supports_gpus_per_node():
# eos slurm does not support gpus-per-node option
cmd += ["--gpus-per-node", str(args.gpus_per_node)]
cmd += ["--time", args.time]
cmd += ["--exclusive"]
cmd += ["timeout", timeout]
cmd += args.cmd
full_cmd = " ".join(cmd)
if os.environ.get("SLURM_JOB_ID"):
print(colored("Running inside slurm nodes detected", "yellow"))
full_cmd = " ".join(args.cmd)
print(colored(full_cmd, attrs=["bold"]))
# Run the job and resume if it times out
fail_times = 0
while True:
returncode = subprocess.run(full_cmd, env=env, shell=True).returncode
print(f"returncode: {returncode}")
if returncode == 0:
print("Job finished successfully!")
break
if returncode != 124:
fail_times += 1
if fail_times > args.max_retry:
break
print(f"Job failed, retrying {fail_times} / {args.max_retry}")
else:
fail_times = 0
print("Job timed out, retrying...")
# Exit with the return code
print(f"Job finished with exit code {returncode}")
exit(returncode)
if __name__ == "__main__":
main()