import logging
from dataclasses import dataclass

import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm

from ...common import Normalizer

logger = logging.getLogger(__name__)


@dataclass
class IRMAEOutput:
    latent: Tensor  # latent vector, shape (b c t)
    decoded: Tensor | None  # decoder output; None when decoding is skipped


class ResBlock(nn.Sequential):
    def __init__(self, channels, dilations=(1, 2, 4, 8)):
        # Pre-activation stack of dilated convolutions: GroupNorm -> GELU ->
        # weight-normalized Conv1d for each dilation. "same" padding keeps the
        # time dimension fixed so the residual addition in forward() lines up.
        layers = []
        for dilation in dilations:
            layers += [
                nn.GroupNorm(32, channels),
                nn.GELU(),
                weight_norm(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilation)),
            ]
        super().__init__(*layers)

    def forward(self, x: Tensor):
        return x + super().forward(x)
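
# With kernel size 3 and dilations (1, 2, 4, 8), one ResBlock covers a
# receptive field of 1 + 2*(1 + 2 + 4 + 8) = 31 time steps; the four blocks
# stacked in IRMAE below widen that to 1 + 4*30 = 121 steps per frame.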


class IRMAE(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension
            output_dim: output dimension
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irm_matrics: number of implicit rank minimization matrices
            norm: normalization layer
        """
        self.input_dim = input_dim
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            # Stack of linear maps (1x1 convolutions, no bias): their product
            # implicitly minimizes the rank of the latent representation
            # (IRMAE, https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
            *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
            nn.Tanh(),
        )

        self.decoder = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            nn.Conv1d(hidden_dim, output_dim, 1),
        )

        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()

    def encode(self, x):
        """
        Args:
            x: (b c t) tensor
        """
        z = self.encoder(x)  # (b c t)
        _ = self.estimator(z)  # Estimate the global mean and std of z
        # Track latent statistics; the quantile levels correspond to the
        # Gaussian 1/2/3-sigma coverage probabilities.
        self.stats = {
            "z_mean": z.mean().item(),
            "z_std": z.std().item(),
            "z_abs_68": z.abs().quantile(0.6827).item(),
            "z_abs_95": z.abs().quantile(0.9545).item(),
            "z_abs_99": z.abs().quantile(0.9973).item(),
        }
        return z

    def decode(self, z):
        """
        Args:
            z: (b c t) tensor
        """
        return self.decoder(z)

    def forward(self, x, skip_decoding=False):
        """
        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step
        """
        z = self.encode(x)  # q(z|x)

        if skip_decoding:
            # Skipping the decoder speeds up training in CFM-only mode
            decoded = None
        else:
            decoded = self.decode(z)  # p(x|z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
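

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original module. The sizes
    # below (80 input/output channels, 16 latent channels, 100 frames) are
    # assumptions chosen only to exercise the (b, c, t) shape contract.
    # Because of the relative import of Normalizer above, run this as a module
    # from the package root (python -m <package>.irmae), not as a script.
    import torch

    model = IRMAE(input_dim=80, output_dim=80, latent_dim=16)
    x = torch.randn(2, 80, 100)  # (batch, channels, time)

    out = model(x)
    assert out.latent.shape == (2, 16, 100)  # encoder preserves the time axis
    assert out.decoded.shape == (2, 80, 100)  # decoder maps back to output_dim
    print(model.losses["mse"].item(), model.stats)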