Spaces:
Runtime error
Runtime error
import os | |
import torch | |
def get_world_size():
    """Return the number of processes in the job without calling MPI functions.

    Reads ``PMI_SIZE`` first, then ``OMPI_COMM_WORLD_SIZE``. A variable that is
    set but empty counts as a world size of 1. When neither variable is set,
    falls back to the number of visible CUDA devices, clamped to at least 1 so
    that a single CPU-only process still reports a valid world size.

    :rtype: int
    """
    pmi_size = os.environ.get('PMI_SIZE')
    if pmi_size is not None:
        return int(pmi_size or 1)
    ompi_size = os.environ.get('OMPI_COMM_WORLD_SIZE')
    if ompi_size is not None:
        return int(ompi_size or 1)
    # torch.cuda.device_count() is 0 when CUDA is unavailable; a world size of
    # 0 is never valid (get_global_rank() still reports rank 0), so treat that
    # case as a single process.
    return max(torch.cuda.device_count(), 1)
def get_global_rank():
    """Return this process's global rank without calling MPI functions.

    The launcher variables ``PMI_RANK`` and ``OMPI_COMM_WORLD_RANK`` are
    consulted in that order. The first one that is present decides the result,
    even if it is empty; an empty value maps to rank 0. If neither is present,
    the rank is 0.

    :rtype: int
    """
    for var_name in ('PMI_RANK', 'OMPI_COMM_WORLD_RANK'):
        raw_value = os.environ.get(var_name)
        if raw_value is not None:
            return int(raw_value or 0)
    return 0
def get_local_rank():
    """Return this process's node-local rank without calling MPI functions.

    Looks for ``MPI_LOCALRANKID`` and then ``OMPI_COMM_WORLD_LOCAL_RANK``. The
    first variable present in the environment wins, and an empty value is read
    as 0. If neither variable is present, 0 is returned.

    :rtype: int
    """
    candidates = ('MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK')
    found = next(
        (os.environ[key] for key in candidates if key in os.environ), None
    )
    return 0 if found is None else int(found or 0)
def get_master_ip():
    """Return the host address of the master node.

    Checks ``AZ_BATCH_MASTER_NODE`` first; its value may carry a ``:port``
    suffix, which is stripped. Next checks ``AZ_BATCHAI_MPI_MASTER_NODE``,
    whose value is returned as-is. Otherwise returns the loopback address
    ``"127.0.0.1"``. As with the rank/size helpers, a variable that is set but
    yields an empty host falls back to the default instead of returning ``""``.

    :rtype: str
    """
    default = "127.0.0.1"
    batch_master = os.environ.get('AZ_BATCH_MASTER_NODE')
    if batch_master is not None:
        # Value may be "host:port"; keep only the host part.
        return batch_master.split(':')[0] or default
    batchai_master = os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    if batchai_master is not None:
        return batchai_master or default
    return default