Spaces:
Build error
Build error
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" | |
import torch | |
#---------------------------------------------------------------------------- | |
def fma(a, b, c): # => a * b + c
    """Return `a * b + c`, computed through `_FusedMultiplyAdd`.

    The three tensor arguments are forwarded unchanged to the custom
    autograd function, which evaluates the expression with
    `torch.addcmul()` and supplies its own backward pass.
    """
    fused = _FusedMultiplyAdd.apply(a, b, c)
    return fused
#---------------------------------------------------------------------------- | |
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Custom autograd function computing `a * b + c` with a hand-written backward.

    `forward` and `backward` are static methods, as required for
    `torch.autograd.Function` subclasses invoked through `.apply()`.
    """

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        """Compute `c + a * b` and stash what the backward pass needs.

        Args:
            ctx: Autograd context object used to carry state to `backward`.
            a, b: Tensors that are multiplied together.
            c: Tensor added to the product.

        Returns:
            The result of `torch.addcmul(c, a, b)`.
        """
        out = torch.addcmul(c, a, b)
        # Only `a` and `b` are needed as values for the gradients; for `c`
        # the shape alone is enough, so the tensor itself is not saved.
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        """Return the gradients of the output with respect to `a`, `b` and `c`.

        Each gradient is computed only when `ctx.needs_input_grad` requests
        it, and is reduced with `_unbroadcast` to the shape of the
        corresponding input; gradients that are not requested are `None`.

        Args:
            ctx: Autograd context populated by `forward`.
            dout: Gradient with respect to the output of `forward`.

        Returns:
            Tuple `(da, db, dc)`.
        """
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc
#---------------------------------------------------------------------------- | |
def _unbroadcast(x, shape):
    """Reduce the broadcast tensor `x` back to `shape` by summation.

    Sums over every axis of `x` that broadcasting introduced or expanded:
    the leading axes that `shape` does not have, and the axes where
    `shape` has size 1 but `x` does not. The leading axes are then dropped.

    Args:
        x: Tensor whose shape is the result of broadcasting `shape`.
        shape: Target shape (e.g. a `torch.Size`) with at most `x.ndim`
            dimensions; may be empty for a 0-dim target.

    Returns:
        A tensor of shape `shape`; `x` itself when nothing has to be reduced.

    Raises:
        AssertionError: If `shape` has more dimensions than `x`, or if the
            reduced tensor does not end up with shape `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Size-1 axes need no reduction. Every other size is reduced, including
    # zero-sized axes (hence `!= 1` rather than `> 1`), whose sum is 0.
    dim = [
        i for i, size in enumerate(x.shape)
        if size != 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if dim:
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # All leading axes now have size 1; drop them with an explicit target
        # shape so that 0-dim targets and zero-sized trailing axes work too
        # (a `-1` placeholder cannot be inferred in those cases).
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x
#---------------------------------------------------------------------------- | |