Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function | |
import torch | |
from torch import nn | |
class GradientReversal(Function):
    """Autograd function that is the identity in forward and flips gradients.

    Forward returns ``x`` unchanged. Backward returns ``-alpha * grad_output``
    for ``x`` and ``None`` for ``alpha`` (no gradient flows to the scale).
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            x: Input tensor, passed through as-is.
            alpha: Tensor scale applied to the reversed gradient.
        """
        # Only alpha is read in backward; saving x as well would keep the
        # activation alive until backward for no benefit.
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-alpha * grad_output, None)``.

        The first entry is ``None`` when ``x`` does not need a gradient.
        """
        (alpha,) = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        # One gradient slot per forward input: x, then alpha.
        return grad_input, None


# Bound here, before the nn.Module of the same name defined below rebinds
# ``GradientReversal``; ``revgrad`` keeps pointing at the autograd function.
revgrad = GradientReversal.apply
class GradientReversal(nn.Module):
    """Module wrapper that applies ``revgrad`` to its input with a fixed scale.

    This class reuses the name of the autograd function defined earlier in
    the file; ``revgrad`` was bound before this definition, so it still
    refers to that function.

    Args:
        alpha: Value converted to a tensor with ``torch.tensor`` and passed
            to ``revgrad`` as the gradient scale.
    """

    def __init__(self, alpha):
        super().__init__()
        # Stored as a plain tensor attribute: it is neither an nn.Parameter
        # nor a registered buffer, so it does not appear in state_dict().
        # NOTE(review): presumably it is therefore not moved by .to()/.cuda()
        # — confirm whether that matters for callers.
        scale = torch.tensor(alpha, requires_grad=False)
        self.alpha = scale

    def forward(self, x):
        """Return ``revgrad(x, self.alpha)``."""
        reversed_x = revgrad(x, self.alpha)
        return reversed_x