File size: 2,543 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class Transformer(nn.Module):
    """Sinusoidal-position Transformer encoder with a linear output head.

    The encoder runs at width ``d_model = cfg.input_dim``; there is no input
    projection. Processing has three steps:

    1. Positional encodings are added to the input.
    2. The sequence passes through ``cfg.n_layers`` batch-first
       ``TransformerEncoderLayer``s, each with ``cfg.n_heads`` heads.
    3. A single ``nn.Linear`` maps the result to ``cfg.output_dim``.

    Args:
        cfg: configuration object read for ``dropout``, ``n_heads``,
            ``n_layers``, ``input_dim`` and ``output_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        p_drop, n_heads, n_layers = cfg.dropout, cfg.n_heads, cfg.n_layers
        width, out_width = cfg.input_dim, cfg.output_dim

        # Keep this construction order. It fixes the RNG draws used for weight
        # initialization and the order of self.parameters().
        self.pos_encoder = PositionalEncoding(width, p_drop)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(
                width, n_heads, dropout=p_drop, batch_first=True
            ),
            n_layers,
        )
        self.output_mlp = nn.Linear(width, out_width)

    def forward(self, x, mask=None):
        """Encode ``x`` and project every frame to the output width.

        Args:
            x: (N, seq_len, input_dim)
            mask: unused. NOTE(review): it is never forwarded to the encoder,
                so no attention/padding mask is applied. Confirm that callers
                passing a mask do not rely on it.
        Returns:
            output: (N, seq_len, output_dim)
        """
        hidden = self.pos_encoder(x)  # (N, seq_len, d_model)
        hidden = self.transformer_encoder(hidden)  # (N, seq_len, d_model)
        return self.output_mlp(hidden)  # (N, seq_len, output_dim)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first inputs.

    Channel ``2i`` holds ``sin(pos / 10000^(2i / d_model))`` and channel
    ``2i + 1`` holds the matching cosine. The table is precomputed for
    ``max_len`` positions and registered as the buffer ``pe`` of shape
    ``(1, max_len, d_model)``. Dropout is applied after the addition.

    Args:
        d_model: feature width of the inputs. It may be odd; the last channel
            is then a sine with no cosine partner.
        dropout: dropout probability applied to the encoded output.
        max_len: number of positions precomputed. ``forward`` supports
            sequences up to this length.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # (max_len, 1): integer positions, broadcast against div_term below.
        position = torch.arange(max_len).unsqueeze(1)
        # (ceil(d_model / 2),): 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd-indexed channels. That is one
        # fewer than the even channels when d_model is odd, so only that many
        # frequencies are used here. Without this slice, odd widths raise a
        # shape-mismatch error.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer moves with the module across .to() and is saved in the
        # state_dict, but it is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [N, seq_len, d] with ``seq_len <= max_len``.
                Longer sequences fail with a broadcasting error, because the
                slice below yields at most ``max_len`` positions.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # self.pe is (1, max_len, d). Slicing the time axis lets it broadcast
        # over the batch dimension.
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)