from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear


class ProjectionLayer(nn.Module):
    """Project model features to log-probabilities over a vocabulary.

    The layer applies the project's ``Linear`` wrapper, built with
    ``(d_model, vocab_size)``, and then a log-softmax over the last
    dimension of the result.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Build the projection.

        Args:
            d_model: Passed as the first argument to ``Linear``; per the
                shape comment in ``forward`` this is the size of the
                input's last dimension.
            vocab_size: Passed as the second argument to ``Linear``; per
                the shape comment in ``forward`` this is the size of the
                output's last dimension.
        """
        super().__init__()
        # NOTE(review): ``Linear`` comes from ``modules.wrapper``; presumably
        # it behaves like ``nn.Linear(in_features, out_features)`` -- confirm.
        self.linear = Linear(d_model, vocab_size)

    def forward(self, x: Tensor) -> Tensor:
        """Return log-softmax of the linear projection of ``x``.

        Args:
            x: Input tensor; the original shape comment gives it as
                ``(batch, seq_len, d_model)``.

        Returns:
            ``log_softmax`` of ``self.linear(x)`` taken over the last
            dimension; the original shape comment gives the result as
            ``(batch, seq_len, vocab_size)``.
        """
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        logits = self.linear(x)
        # Normalize over the last axis only, so each position gets its own
        # distribution.
        return torch.log_softmax(logits, dim=-1)