# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    """DDP-style wrapper for TPU/XLA training; gradients are reduced
    explicitly via ``all_reduce_grads()`` after the backward pass."""

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                # Materialize a zero gradient so every replica reduces the
                # same set of tensors.
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            gradients.append(p.grad)

        # Imported lazily so the module can be loaded without torch_xla.
        import torch_xla.core.xla_model as xm

        # Sum gradients across replicas in place, scaling by 1/world_size to
        # average them; process_group[1] holds the XLA replica groups.
        xm.all_reduce(
            "sum",
            gradients,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
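

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original fairseq module).
# Assumptions: a single XLA replica, and a hand-built ``process_group`` tuple
# whose second element is the replica-group spec that ``xm.all_reduce``
# accepts (matching the ``self.process_group[1]`` access above). In real
# fairseq training this tuple is produced during fairseq's distributed
# initialization rather than written by hand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = nn.Linear(8, 2).to(device)

    # Hypothetical single-replica group spec: one group containing rank 0.
    process_group = ("tpu", [[0]])
    wrapped = TPUDistributedDataParallel(model, process_group)

    loss = wrapped(torch.randn(4, 8, device=device)).sum()
    loss.backward()

    # Sum-reduce gradients across replicas (a no-op average with one replica)
    # and flush the pending XLA graph.
    wrapped.all_reduce_grads()
    xm.mark_step()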