# homemade_lo_vi/layers/projection_layer.py
import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear


class ProjectionLayer(nn.Module):
    """Projects model hidden states onto the vocabulary as log-probabilities."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.linear = Linear(d_model, vocab_size)

    def forward(self, x: Tensor) -> Tensor:
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.linear(x), dim=-1)
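

# Minimal usage sketch. The shapes and sizes below (d_model=512, vocab_size=32000,
# batch=2, seq_len=10) are illustrative assumptions, not values taken from the repo,
# and `modules.wrapper.Linear` is assumed to behave like `nn.Linear`.
if __name__ == "__main__":
    d_model, vocab_size = 512, 32000
    layer = ProjectionLayer(d_model, vocab_size)

    x = torch.randn(2, 10, d_model)      # (batch, seq_len, d_model)
    log_probs = layer(x)                 # (batch, seq_len, vocab_size)
    print(log_probs.shape)               # torch.Size([2, 10, 32000])

    # Exponentiating the log-probabilities gives a valid distribution
    # over the vocabulary at every position (sums are ~1.0).
    print(log_probs.exp().sum(dim=-1))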