import logging

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


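class Normalizer(nn.Module):
    """Normalize inputs with scalar running statistics tracked by an exponential moving average."""

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum  # EMA weight given to each new batch statistic
        self.eps = eps  # added to the variance before the square root
        # Type hints only; the buffers themselves are registered below.
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        # NaN marks "no statistics yet"; read via the running_mean / running_std properties.
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))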

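    @property
    def started(self) -> bool:
        """Whether an update has replaced the NaN placeholder."""
        return not bool(torch.isnan(self.running_mean_unsafe))

    @property
    def running_mean(self):
        # Before the first update, fall back to 0 so normalization is the identity.
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self):
        # Before the first update, fall back to 1 so normalization is the identity.
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        # eps keeps the square root away from zero for near-constant inputs.
        return (self.running_var_unsafe + self.eps).sqrt()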

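    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor):
        # Exponential moving average: keep (1 - momentum) of the old value.
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()  # keep autograd history out of the buffers
    def update_(self, x):
        if not self.started:
            # First batch: initialise directly from the batch statistics.
            # Note that x.var() is the unbiased estimate, while the EMA branch uses a plain mean.
            self.running_mean_unsafe = x.mean()
            self.running_var_unsafe = x.var()
        else:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
            # The squared deviation is measured against the already-updated running mean.
            sq_dev = (x - self.running_mean).pow(2).mean()
            self.running_var_unsafe = self._ema(self.running_var_unsafe, sq_dev)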

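    def forward(self, x: Tensor, update=True):
        # Statistics are updated only in training mode, and only when requested.
        if self.training and update:
            self.update_(x)
        # Plain-float snapshot for logging; .item() forces a device-to-host sync.
        self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
        x = (x - self.running_mean) / self.running_std
        return x

    def inverse(self, x: Tensor):
        # Undo forward() using the current running statistics.
        return x * self.running_std + self.running_mean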
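

# Usage sketch (not part of the original module): a minimal round-trip check.
# The input shape, seed, scale/offset, and tolerance are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    norm = Normalizer()
    norm.train()
    x = torch.randn(8, 16) * 3.0 + 5.0
    y = norm(x)  # the first call initialises the running statistics from x
    # inverse() undoes forward() as long as the statistics are unchanged in between.
    assert torch.allclose(norm.inverse(y), x, atol=1e-4)
    print(norm.stats)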