Spaces:
Runtime error
Runtime error
import torch | |
import torch.distributed as dist | |
def setup_dist(local_rank):
    '''Bind this process to a CUDA device and join the default process group.

    Does nothing when a default process group has already been initialized.
    Otherwise the current CUDA device is set to ``local_rank`` and a process
    group is created with the NCCL backend and the ``env://`` init method.

    Args:
        local_rank: Device index handed to ``torch.cuda.set_device``.
    '''
    if not dist.is_initialized():
        # Select the device first, then create the NCCL group.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
def gather_data(data, return_np=True):
    '''Gather a tensor from every process into one list.

    Args:
        data: Tensor contributed by the calling process. The receive buffers
            are created with ``torch.zeros_like(data)``, so presumably every
            rank must pass a tensor of the same shape and dtype -- confirm
            against callers.
        return_np: When true, each gathered tensor is moved to the CPU and
            converted to a NumPy array; otherwise the tensors are returned.

    Returns:
        A list with one entry per process. When no default process group is
        initialized (e.g. a plain single-process run), a one-entry list
        holding a detached copy of ``data`` is returned instead of raising.
    '''
    if dist.is_available() and dist.is_initialized():
        gathered = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
        # all_gather is used because gather is not supported with NCCL.
        dist.all_gather(gathered, data)
    else:
        # No process group: behave like a world of size one. The copy keeps
        # the result independent of the caller's tensor, as in the
        # distributed path where fresh buffers are filled.
        gathered = [data.detach().clone()]

    if return_np:
        gathered = [tensor.cpu().numpy() for tensor in gathered]
    return gathered