import torch | |
import torch.nn as nn | |
def log(t, eps = 1e-20): | |
return torch.log(t.clamp(min = eps)) | |
def gumbel_noise(t):
    """Draw noise shaped like ``t`` via the transform ``-log(-log(U))``.

    ``U`` is sampled uniformly from ``[0, 1)`` with the same shape, dtype and
    device as ``t``; the values of ``t`` themselves are not used. Both
    logarithms go through the clamped ``log`` helper of this module so that
    samples at the edge of the interval do not hit ``log(0)``.

    Args:
        t: Tensor whose shape, dtype and device the noise should match.

    Returns:
        Tensor of Gumbel-distributed noise with the same shape as ``t``.
    """
    uniform = torch.rand_like(t)
    neg_log_u = -log(uniform)
    return -log(neg_log_u)
class NoisyGate(nn.Module):
    """Linear gate over experts whose logits are perturbed with Gumbel noise.

    A single ``nn.Linear`` maps ``hidden_dim`` features to ``num_experts``
    logits; ``forward`` then adds ``noise_mult`` times freshly sampled noise
    from ``gumbel_noise`` to those logits.
    """

    def __init__(self, hidden_dim, num_experts, noise_mult=1.0, bias=False):
        """Build the gate.

        Args:
            hidden_dim: Size of the input feature dimension.
            num_experts: Number of logits produced per input vector.
            noise_mult: Scale factor applied to the sampled noise.
            bias: Whether the linear projection has a bias term.
        """
        super().__init__()
        # Keep the constructor arguments available as plain attributes.
        self.hidden_dim, self.num_experts = hidden_dim, num_experts
        self.noise_mult, self.bias = noise_mult, bias
        self.gate = nn.Linear(hidden_dim, num_experts, bias=bias)

    def forward(self, x):
        """Return noisy gate logits for ``x``.

        Args:
            x: Input tensor; ``nn.Linear`` requires its last dimension to be
                ``hidden_dim``.

        Returns:
            Tensor of logits (last dimension ``num_experts``) with scaled
            noise added. The noise is sampled on every call; this method
            does not branch on training/eval mode.
        """
        logits = self.gate(x)
        perturbation = gumbel_noise(logits) * self.noise_mult
        return logits + perturbation