File size: 3,555 Bytes
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .wavenet import WaveNet


class ReferenceEncoder(WaveNet):
    """WaveNet encoder followed by learned-latent cross-attention pooling.

    The inherited ``WaveNet`` stack encodes the input sequence. A small set of
    learned latent query vectors then cross-attends over the encoded frames;
    the attended latents pass through a residual MLP and an output projection
    and are averaged into a single vector per batch item.

    Args:
        input_channels: Forwarded unchanged to ``WaveNet.__init__``.
        output_channels: Size of the returned embedding. Required in practice
            even though the signature defaults to ``None``.
        residual_channels: Forwarded to ``WaveNet.__init__``; also the width
            of the attention and MLP layers defined here.
        residual_layers: Forwarded unchanged to ``WaveNet.__init__``.
        dilation_cycle: Forwarded unchanged to ``WaveNet.__init__``.
        num_heads: Number of attention heads. Must be positive and divide
            ``residual_channels`` evenly.
        latent_len: Number of learned latent query vectors.

    Raises:
        TypeError: If ``output_channels`` is ``None``.
        ValueError: If ``num_heads`` is not positive or does not divide
            ``residual_channels``.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        # Fail fast with a readable message: nn.Linear below cannot be built
        # with a None output size.
        if output_channels is None:
            raise TypeError(
                "output_channels must be an int; ReferenceEncoder needs it for "
                "its output projection"
            )
        # head_dim is computed with floor division, so a non-divisible setup
        # would only surface later as a reshape error inside forward().
        if num_heads <= 0 or residual_channels % num_heads != 0:
            raise ValueError(
                f"residual_channels ({residual_channels}) must be divisible by "
                f"a positive num_heads ({num_heads})"
            )

        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads

        # Learned queries, shared across the batch (leading dim of 1 is
        # expanded to the batch size in forward()).
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))

        # Cross-attention projections; keys and values share one fused linear.
        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        # Per-head normalisation of queries and keys.
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)

        # Pre-norm residual MLP with a 4x hidden expansion.
        self.norm = nn.LayerNorm(residual_channels)
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, residual_channels * 4),
            nn.SiLU(),
            nn.Linear(residual_channels * 4, residual_channels),
        )
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        # apply() walks the whole module tree, so this re-initialises every
        # nn.Linear, including any created by the WaveNet base class.
        self.apply(self.init_weights)

    def init_weights(self, m: nn.Module) -> None:
        """Initialise an ``nn.Linear``: truncated-normal weight, zero bias.

        Modules of any other type are left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask=None):
        """Encode ``x`` and pool it into one embedding per batch item.

        Args:
            x: Input handed to ``WaveNet.forward``. Presumably shaped
                ``(batch, input_channels, frames)`` as in this file's
                self-test; confirm against ``WaveNet``.
            attn_mask: Optional ``torch.bool`` tensor of shape
                ``(batch, frames)``. It is passed to
                ``F.scaled_dot_product_attention`` as a boolean mask, where
                ``True`` marks frames that take part in attention.
                NOTE(review): a row with no ``True`` entry leaves the latents
                nothing to attend to; the result for such rows depends on the
                attention backend and may be NaN.

        Returns:
            Tensor of shape ``(batch, output_channels)``.

        Raises:
            ValueError: If ``attn_mask`` is given but is not a bool tensor of
                shape ``(batch, frames)``.
        """
        # Swap the last two axes of the WaveNet output so frames come before
        # channels: (batch, frames, channels).
        x = super().forward(x).mT
        B, N, C = x.shape

        # Calculate mask
        if attn_mask is not None:
            # Explicit raise rather than assert: asserts vanish under -O, and
            # a float or mis-shaped mask would then be silently misinterpreted.
            if attn_mask.shape != (B, N) or attn_mask.dtype != torch.bool:
                raise ValueError(
                    f"attn_mask must be a torch.bool tensor of shape {(B, N)}, "
                    f"got dtype {attn_mask.dtype} and shape "
                    f"{tuple(attn_mask.shape)}"
                )

            # Same key mask for every head and every latent query.
            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )

        # Queries come from the learned latents: (B, heads, latent_len, head_dim).
        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Keys/values come from the encoded frames; after the permute the
        # leading axis separates k from v, each (B, heads, N, head_dim).
        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        # Merge heads back: (B, latent_len, C).
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))
        x = self.output_projection_attn(x)
        # Average over the latent axis to get a single vector per item.
        x = x.mean(1)

        return x


if __name__ == "__main__":
    # Smoke test: a single forward/backward pass on CPU under bfloat16
    # autocast, printing the shape of the pooled embedding.
    batch_size, in_channels, out_channels, num_frames = 4, 128, 64, 64

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        encoder = ReferenceEncoder(
            input_channels=in_channels,
            output_channels=out_channels,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )

        features = torch.randn(batch_size, in_channels, num_frames)
        # All-True mask: every frame takes part in attention.
        frame_mask = torch.ones(batch_size, num_frames, dtype=torch.bool)

        embedding = encoder(features, frame_mask)
        print(embedding.shape)

        # Backpropagate a dummy regression loss to check gradients flow.
        target = torch.randn(batch_size, out_channels)
        F.mse_loss(embedding, target).backward()