import os
import random
import torch
import numpy as np

ALREADY_INITIALIZED = False

# TODO: Consider torch.distributed.is_initialized() instead


def init(rank):
    """Initialize the distributed process group (once) and return (nranks, is_distributed)."""
    nranks = int(os.environ.get('WORLD_SIZE', 1))
    nranks = max(1, nranks)

    # Treat the run as distributed if there are multiple ranks, or if a launcher
    # (e.g., torchrun) exported WORLD_SIZE even for a single process.
    is_distributed = (nranks > 1) or ('WORLD_SIZE' in os.environ)

    global ALREADY_INITIALIZED
    if ALREADY_INITIALIZED:
        return nranks, is_distributed

    ALREADY_INITIALIZED = True

    if is_distributed and torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f'nranks = {nranks} \t num_gpus = {num_gpus} \t device={rank % num_gpus}')

        # Pin this process to one GPU, then join the NCCL process group.
        torch.cuda.set_device(rank % num_gpus)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    return nranks, is_distributed


def barrier(rank):
    """Block until all ranks reach this point; a no-op for single-process runs."""
    nranks = int(os.environ.get('WORLD_SIZE', 1))
    nranks = max(1, nranks)

    # Only synchronize if a process group was actually created in init().
    if rank >= 0 and nranks > 1 and torch.distributed.is_initialized():
        torch.distributed.barrier(device_ids=[rank % torch.cuda.device_count()])
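

# --- Minimal usage sketch (assumptions: the script is launched via `torchrun`,
# --- which sets RANK and WORLD_SIZE; the per-rank "work" below is purely
# --- illustrative and not part of this module's API).
if __name__ == '__main__':
    demo_rank = int(os.environ.get('RANK', 0))

    demo_nranks, demo_is_distributed = init(demo_rank)
    print(f'rank {demo_rank}: nranks={demo_nranks}, distributed={demo_is_distributed}')

    # ... per-rank work (e.g., a training step) would go here ...

    # Synchronize all ranks before, e.g., rank 0 writes a checkpoint.
    barrier(demo_rank)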