File size: 672 Bytes
bc1ada8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import math

import numpy as np
import torch.nn as nn
from torch import Tensor
class TransformerEmbedding(nn.Module):
    """
    Input Embeddings (section 3.4)
    Embeds token indices into vectors of size d_model and scales them by sqrt(d_model).
    Args:
        - d_model (int): dimension of model (size of each embedding vector)
        - num_embeddings (int): size of the dictionary (number of distinct token indices)
    """

    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        # Plain Python float (not a NumPy scalar). math.sqrt is correctly rounded,
        # so the value is the same as np.sqrt for float64.
        self.sqrt_d_model = math.sqrt(d_model)
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=d_model)

    def forward(self, x: Tensor) -> Tensor:
        """
        Look up the embeddings for ``x`` and scale them by sqrt(d_model).

        Args:
            - x (Tensor): integer token indices, each in [0, num_embeddings).
        Returns:
            Tensor of shape ``x.shape + (d_model,)``.
        """
        return self.embedding(x) * self.sqrt_d_model
|