# VideoCrafter / lvdm/utils/dist_utils.py
# Uploaded by imjunaidafzal ("Duplicate from VideoCrafter/VideoCrafter"), revision 153e804, 603 bytes.
import torch
import torch.distributed as dist
def setup_dist(local_rank, backend='nccl'):
    '''Initialise the default torch.distributed process group for this process.

    Returns immediately if the default process group already exists, so it is
    safe to call more than once. Rendezvous uses ``init_method='env://'``:
    PyTorch reads the master address/port, rank and world size from
    environment variables (presumably set by a launcher such as torchrun —
    confirm against the calling script).

    Args:
        local_rank: index of the CUDA device this process should bind to.
        backend: process-group backend. Defaults to ``'nccl'``, which matches
            the previous hard-coded behaviour.
    '''
    if dist.is_initialized():
        return
    # NCCL needs the CUDA device selected before the group is created, so the
    # default path always calls set_device, as before. Other backends (e.g.
    # gloo) only bind a device when CUDA is actually present, so CPU-only runs work.
    if backend == 'nccl' or torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend, init_method='env://')
def gather_data(data, return_np=True):
    '''Gather ``data`` from every process in the default group into one list.

    Uses ``dist.all_gather``, so every rank receives the full list, not just
    rank 0. The receive buffers are allocated with ``torch.zeros_like(data)``,
    so each rank must contribute a tensor with the same shape and dtype as the
    local one (assumed to hold for all callers — TODO confirm).

    Args:
        data: this process's tensor contribution.
        return_np: if True (default), convert every gathered tensor to a NumPy
            array via ``.cpu().numpy()``; otherwise return the tensors as-is.

    Returns:
        A list with one entry per rank in the default process group.
    '''
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(data) for _ in range(world_size)]
    # The original author notes that dist.gather was not supported with NCCL,
    # hence all_gather; this may not hold for newer PyTorch releases.
    dist.all_gather(gathered, data)
    if return_np:
        # A distinct loop name avoids shadowing the ``data`` argument.
        gathered = [tensor.cpu().numpy() for tensor in gathered]
    return gathered