import random

import torch
import torch.nn as nn

from modules.commons.common_layers import *


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
        Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, hidden_size=256):
        """
        Args:
            p (float): probability of applying MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): small constant added to the std to avoid numerical issues.
            hidden_size (int): channel dimension of the input features.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self._activated = True
        self.hidden_size = hidden_size
        self.affine_layer = LinearNorm(
            hidden_size,
            2 * hidden_size,  # for both b (bias) and g (gain)
        )

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x, spk_embed):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Normalize the features per frame over the channel dimension.
        mu, sig = torch.mean(x, dim=-1, keepdim=True), torch.std(x, dim=-1, keepdim=True)
        x_normed = (x - mu) / (sig + self.eps)  # [B, T, H_m]

        # Mixing coefficient sampled from Beta(alpha, alpha), one per batch element.
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        # Predict bias and gain from the speaker embedding.
        mu1, sig1 = torch.split(self.affine_layer(spk_embed), self.hidden_size, dim=-1)  # [B, 1, 2 * H_m] --> 2 * [B, 1, H_m]

        # MixStyle: mix the statistics with those of a randomly permuted batch.
        perm = torch.randperm(B)
        mu2, sig2 = mu1[perm], sig1[perm]

        mu_mix = mu1 * lmda + mu2 * (1 - lmda)
        sig_mix = sig1 * lmda + sig2 * (1 - lmda)

        # Perform scaling and shifting with the mixed statistics.
        return sig_mix * x_normed + mu_mix  # [B, T, H_m]
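

# A minimal usage sketch (smoke test) for MixStyle. The shapes below are
# assumptions inferred from the comments in forward(): x is [B, T, H_m]
# frame-level features and spk_embed is [B, 1, H_m]; LinearNorm is assumed
# to behave like a plain nn.Linear over the last dimension.
if __name__ == '__main__':
    layer = MixStyle(p=1.0, alpha=0.1, hidden_size=256)  # p=1.0 so mixing always fires
    layer.train()

    x = torch.randn(4, 100, 256)        # [B, T, H_m] content features
    spk_embed = torch.randn(4, 1, 256)  # [B, 1, H_m] speaker embedding

    out = layer(x, spk_embed)
    print(out.shape)  # expected: torch.Size([4, 100, 256])

    # In eval mode (or after set_activation_status(False)) the layer is a no-op.
    layer.eval()
    assert torch.equal(layer(x, spk_embed), x)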