File size: 672 Bytes
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch.nn as nn
from torch import Tensor
import numpy as np


class TransformerEmbedding(nn.Module):
    """
    Scaled input embedding layer (input embeddings, section 3.4).

    Looks up a learned vector of size ``d_model`` for every index in the
    input and multiplies the looked-up vectors by ``sqrt(d_model)``.

    Args:
        - d_model (int): dimension of the model, i.e. the size of each
          embedding vector
        - num_embeddings (int): size of the dictionary, i.e. the number of
          rows in the embedding table
    """

    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        # Lookup table: one row of length d_model per dictionary entry.
        self.embedding = nn.Embedding(num_embeddings, d_model)
        # Scale factor is constant for the module, so compute it once here
        # rather than on every forward pass.
        self.sqrt_d_model = np.sqrt(d_model)

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed the indices in ``x`` and scale the result by ``sqrt(d_model)``.

        Args:
            - x (Tensor): indices to look up in the embedding table

        Returns:
            Tensor: the looked-up embedding vectors multiplied by
            ``sqrt(d_model)``
        """
        looked_up = self.embedding(x)
        return looked_up * self.sqrt_d_model