import torch
from torch import nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    """Identity in the forward pass; multiplies the gradient by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Only alpha is needed in backward, so there is no need to save x.
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        alpha, = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # Reverse (and scale) the gradient flowing back to x.
            grad_input = -alpha * grad_output
        # alpha receives no gradient, hence the trailing None.
        return grad_input, None


revgrad = GradientReversalFunction.apply


class GradientReversal(nn.Module):
    """Module wrapper around revgrad so the layer can be dropped into nn.Sequential."""

    def __init__(self, alpha):
        super().__init__()
        # Keep alpha as a non-trainable tensor; save_for_backward expects tensors.
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)
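

# --- Usage sketch (illustrative addition, not part of the original snippet) ---
# A quick sanity check of the behaviour above: the forward pass is the
# identity, while the backward pass negates (and scales by alpha) the
# gradient. The shapes and alpha value here are arbitrary.
if __name__ == "__main__":
    layer = GradientReversal(alpha=1.0)
    x = torch.randn(4, 3, requires_grad=True)

    y = layer(x)
    assert torch.equal(y, x)  # values pass through unchanged

    y.sum().backward()
    # d(sum)/dx is all ones, so the reversed gradient is all minus ones.
    assert torch.equal(x.grad, -torch.ones_like(x))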