tts-rvc-autopst / onmt_modules / position_ffn.py
"""Position feed-forward network from "Attention is All You Need"."""
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
""" A two-layer Feed-Forward-Network with residual layer norm.
Args:
d_model (int): the size of input for the first-layer of the FFN.
d_ff (int): the hidden layer size of the second-layer
of the FNN.
dropout (float): dropout probability in :math:`[0, 1)`.
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
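        # Inner expansion (d_model -> d_ff) and projection back (d_ff -> d_model).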
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.dropout_1 = nn.Dropout(dropout)
self.relu = nn.ReLU()
self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` and add a residual connection.

        Args:
            x (FloatTensor): ``(batch_size, input_len, model_dim)``

        Returns:
            FloatTensor: output of shape ``(batch_size, input_len, model_dim)``.
        """
inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
output = self.dropout_2(self.w_2(inter))
return output + x

    def update_dropout(self, dropout):
        """Set a new dropout probability on both dropout layers."""
        self.dropout_1.p = dropout
        self.dropout_2.p = dropout
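

# Minimal usage sketch (not part of the original module); d_model=512 and
# d_ff=2048 are assumed typical Transformer sizes, not values taken from
# this repository.
if __name__ == "__main__":
    import torch

    ffn = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1)
    x = torch.randn(2, 10, 512)  # (batch_size, input_len, model_dim)
    out = ffn(x)
    assert out.shape == x.shape  # the residual output keeps the input shape
    print(out.shape)  # torch.Size([2, 10, 512])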