Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from enum import Enum | |
import os | |
from pathlib import Path | |
from typing import Any, Dict, Optional | |
class ClusterType(Enum):
    """Compute clusters this module knows how to configure.

    Member values are lowercase short names, so ``ClusterType("aws")`` looks
    up ``ClusterType.AWS``.
    """

    AWS = "aws"  # AWS instances (detected from an "-aws" kernel release suffix)
    FAIR = "fair"  # fallback when no other cluster is detected
    RSC = "rsc"  # RSC instances (detected from an "rsc" hostname prefix)
def _guess_cluster_type() -> ClusterType:
    """Infer which cluster the current process is running on.

    Detection is based on ``os.uname()``:

    * Linux whose kernel release ends with ``"-aws"`` -> ``ClusterType.AWS``
    * otherwise, Linux whose hostname starts with ``"rsc"`` -> ``ClusterType.RSC``
    * anything else -> ``ClusterType.FAIR``

    ``os.uname`` only exists on Unix. On other platforms (e.g. Windows) this
    returns the ``ClusterType.FAIR`` fallback instead of raising
    ``AttributeError``.

    Returns:
        The guessed cluster type; never ``None``.
    """
    uname_fn = getattr(os, "uname", None)
    if uname_fn is None:
        # No uname available: such a host is certainly not Linux, so use the
        # same fallback as any other non-Linux system.
        return ClusterType.FAIR
    uname = uname_fn()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC
    return ClusterType.FAIR
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Resolve the cluster type, auto-detecting it when not given.

    Args:
        cluster_type: Explicit cluster type. Pass ``None`` to detect it from
            the host via ``_guess_cluster_type``.

    Returns:
        ``cluster_type`` unchanged when provided, otherwise the guessed value.
    """
    return cluster_type if cluster_type is not None else _guess_cluster_type()
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    Args:
        cluster_type: Cluster to look up. ``None`` means auto-detect via
            ``get_cluster_type``.

    Returns:
        A path directly under ``/`` (``/checkpoints`` on AWS, ``/checkpoint``
        on FAIR, ``/checkpoint/dino`` on RSC), or ``None`` if the cluster type
        could not be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is not None:
        dirname_by_cluster = {
            ClusterType.AWS: "checkpoints",
            ClusterType.FAIR: "checkpoint",
            ClusterType.RSC: "checkpoint/dino",
        }
        # A value missing from the table raises KeyError.
        return Path("/").joinpath(dirname_by_cluster[resolved])
    return None
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the per-user checkpoint directory for a cluster.

    The result is ``get_checkpoint_path(cluster_type) / $USER``.

    Args:
        cluster_type: Cluster to look up. ``None`` means auto-detect.

    Returns:
        The current user's checkpoint directory, or ``None`` if no checkpoint
        root could be determined for the cluster.

    Raises:
        RuntimeError: If the ``USER`` environment variable is unset or empty.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None
    username = os.environ.get("USER")
    # Validate explicitly: an ``assert`` is stripped under ``python -O``, which
    # would turn a missing USER into a confusing ``Path / None`` TypeError.
    # An empty name is rejected too, because ``path / ""`` silently yields the
    # shared root rather than a per-user directory.
    if not username:
        raise RuntimeError(
            "USER environment variable is not set; cannot build a per-user checkpoint path"
        )
    return checkpoint_path / username
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the default SLURM partition name for a cluster.

    Args:
        cluster_type: Cluster to look up. ``None`` means auto-detect via
            ``get_cluster_type``.

    Returns:
        ``"learnlab"`` for AWS and FAIR, ``"learn"`` for RSC, or ``None`` if the
        cluster type could not be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is not None:
        partition_by_cluster = {
            ClusterType.AWS: "learnlab",
            ClusterType.FAIR: "learnlab",
            ClusterType.RSC: "learn",
        }
        return partition_by_cluster[resolved]
    return None
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build the keyword arguments for a SLURM executor.

    Defaults:

    * one task per GPU, with 10 CPUs per task;
    * ``mem_gb=0``, which requests all of a node's memory;
    * the cluster's default partition.

    AWS and RSC use 12 CPUs per task, and AWS omits ``mem_gb`` entirely.
    Entries in ``kwargs`` are applied last and override any computed value.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: GPUs per node. This is also the number of tasks per
            node.
        cluster_type: Target cluster. ``None`` means auto-detect.
        **kwargs: Additional or overriding executor parameters.

    Returns:
        A new dict of executor parameters.
    """
    resolved = get_cluster_type(cluster_type)
    params: Dict[str, Any] = {}
    if resolved != ClusterType.AWS:
        # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        params["mem_gb"] = 0
    params["gpus_per_node"] = num_gpus_per_node
    params["tasks_per_node"] = num_gpus_per_node  # one task per GPU
    params["cpus_per_task"] = 12 if resolved in (ClusterType.AWS, ClusterType.RSC) else 10
    params["nodes"] = nodes
    params["slurm_partition"] = get_slurm_partition(resolved)
    # Caller-supplied values take precedence over everything above.
    params.update(kwargs)
    return params