Spaces:
Runtime error
Runtime error
File size: 1,050 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import os
import random
import torch
import numpy as np
# Process-wide guard so init() only calls init_process_group once per process.
# NOTE(review): name is misspelled ("INITALIZED") but is a module-level global
# that external code may reference — do not rename without checking callers.
ALREADY_INITALIZED = False
# TODO: Consider torch.distributed.is_initialized() instead
def init(rank):
    """Initialize torch.distributed for this process (idempotent).

    Args:
        rank: Global rank of this process; used to select a CUDA device
            as ``rank % num_gpus``.

    Returns:
        Tuple ``(nranks, is_distributed)``: the world size (always >= 1)
        and whether this is a distributed launch (``WORLD_SIZE`` present
        in the environment, even if it equals 1).
    """
    # WORLD_SIZE is set by torch.distributed launchers (torchrun / launch).
    # Previously computed via a truthy `and` chain that yielded False when
    # the variable was absent and relied on max(1, False) to coerce it.
    nranks = int(os.environ.get('WORLD_SIZE', 1))
    nranks = max(1, nranks)
    # A launcher that sets WORLD_SIZE counts as distributed even when it is 1;
    # the original `(nranks > 1) or (...)` term was redundant, since
    # nranks > 1 is only possible when WORLD_SIZE is present.
    is_distributed = 'WORLD_SIZE' in os.environ

    global ALREADY_INITALIZED
    if ALREADY_INITALIZED:
        return nranks, is_distributed
    ALREADY_INITALIZED = True

    if is_distributed and torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f'nranks = {nranks} \t num_gpus = {num_gpus} \t device={rank % num_gpus}')
        # Round-robin ranks over the visible GPUs on this node.
        torch.cuda.set_device(rank % num_gpus)
        # env:// reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the env.
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    return nranks, is_distributed
def barrier(rank):
    """Synchronize all ranks; a no-op outside multi-process runs.

    Args:
        rank: This process's rank. Pass a negative value to skip the
            barrier explicitly.

    Returns:
        None.
    """
    # Same WORLD_SIZE parsing as init(); replaces the truthy `and` chain
    # that produced False and relied on max(1, False) coercion.
    nranks = int(os.environ.get('WORLD_SIZE', 1))
    nranks = max(1, nranks)

    if rank >= 0 and nranks > 1:
        # device_ids pins the NCCL barrier to this rank's GPU; assumes the
        # nccl backend was initialized by init() — TODO confirm for CPU/gloo.
        torch.distributed.barrier(device_ids=[rank % torch.cuda.device_count()])
|