File size: 6,101 Bytes
63a794c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10845f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import os
import glob

# Module-wide default device: CUDA when it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class FastAutoencoder(nn.Module):
    """TopK sparse autoencoder that also produces multi-k and AuxK (dead-latent) outputs.

    The input minus ``pre_bias`` is encoded by a bias-free linear layer plus
    ``latent_bias``. Only the ``k`` largest pre-activations per row are kept
    (after ReLU), and they are decoded by a bias-free linear layer plus
    ``pre_bias``.

    Args:
        n_dirs: Number of latent directions (dictionary size).
        d_model: Input / reconstruction dimensionality.
        k: Number of latents kept by the main top-k.
        auxk: Number of dead latents selected for the AuxK output, or ``None``
            to disable AuxK.
        multik: Stored as ``self.multik``. NOTE(review): ``forward`` uses
            ``4 * k`` for the multi-k path and never reads this value; confirm
            whether that is intended.
        dead_steps_threshold: A latent is treated as dead once its
            ``stats_last_nonzero`` counter exceeds this value.
    """

    def __init__(self, n_dirs: int, d_model: int, k: int, auxk: int, multik: int, dead_steps_threshold: int = 266):
        super().__init__()
        self.n_dirs = n_dirs
        self.d_model = d_model
        self.k = k
        self.auxk = auxk
        self.multik = multik
        self.dead_steps_threshold = dead_steps_threshold

        self.encoder = nn.Linear(d_model, n_dirs, bias=False)
        self.decoder = nn.Linear(n_dirs, d_model, bias=False)

        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.latent_bias = nn.Parameter(torch.zeros(n_dirs))

        # Per-latent count of forward calls since the latent last appeared in a top-k.
        # It is a buffer, so it follows ``.to()`` / ``.cuda()`` together with the
        # parameters. ``persistent=False`` keeps it out of ``state_dict``, just as
        # the previous plain-tensor attribute was, so existing checkpoints still load.
        self.register_buffer(
            "stats_last_nonzero",
            torch.zeros(n_dirs, dtype=torch.long),
            persistent=False,
        )

    def _sparsify_topk(self, latents, k):
        """Return a dense tensor shaped like ``latents`` holding only the ReLU-ed top-``k`` entries along the last dim."""
        topk_values, topk_indices = torch.topk(latents, k=k, dim=-1)
        sparse = torch.zeros_like(latents)
        sparse.scatter_(-1, topk_indices, F.relu(topk_values))
        return sparse

    def forward(self, x):
        """Encode ``x``, keep the top-k latents, and decode them.

        Side effect: every call adds 1 to each entry of ``stats_last_nonzero``.
        It then resets to 0 every latent that appears in any row's top-k. There
        is no ``self.training`` check, so this also happens in eval mode.

        Returns:
            ``(recons, info)``. ``recons`` is the top-k reconstruction.
            ``info`` holds the top-k, multi-k and AuxK indices/values, the
            multi-k reconstruction, and the dense pre- and post-activation
            latents. ``auxk_indices`` / ``auxk_values`` are ``None`` when
            ``self.auxk`` is ``None``.
        """
        x = x - self.pre_bias
        latents_pre_act = self.encoder(x) + self.latent_bias

        # Main top-k selection; ReLU zeroes any negative value that made the top k.
        topk_values, topk_indices = torch.topk(latents_pre_act, k=self.k, dim=-1)
        topk_values = F.relu(topk_values)
        # Multi-k path: the same encoding with 4x as many active latents.
        multik_values, multik_indices = torch.topk(latents_pre_act, k=4 * self.k, dim=-1)
        multik_values = F.relu(multik_values)

        latents = torch.zeros_like(latents_pre_act)
        latents.scatter_(-1, topk_indices, topk_values)
        multik_latents = torch.zeros_like(latents_pre_act)
        multik_latents.scatter_(-1, multik_indices, multik_values)

        # Dead-latent bookkeeping: one tick per call (not per row).
        self.stats_last_nonzero += 1
        self.stats_last_nonzero.scatter_(0, topk_indices.unique(), 0)

        recons = self.decoder(latents) + self.pre_bias
        multik_recons = self.decoder(multik_latents) + self.pre_bias

        # AuxK: top-auxk latents restricted to those currently counted as dead.
        # This must run after the stats update above.
        if self.auxk is not None:
            dead_mask = self.stats_last_nonzero > self.dead_steps_threshold
            # Zero out the live latents. The mask is cast to the activations'
            # dtype so the AuxK values keep that dtype; a float32 mask would
            # upcast half/bfloat16 activations.
            dead_latents_pre_act = latents_pre_act * dead_mask.to(latents_pre_act.dtype)
            # Zeroed live latents can fill slots when fewer than `auxk` dead
            # latents are positive. ReLU keeps those values at 0.
            auxk_values, auxk_indices = torch.topk(dead_latents_pre_act, k=self.auxk, dim=-1)
            auxk_values = F.relu(auxk_values)
        else:
            auxk_values, auxk_indices = None, None

        return recons, {
            "topk_indices": topk_indices,
            "topk_values": topk_values,
            "multik_indices": multik_indices,
            "multik_values": multik_values,
            "multik_recons": multik_recons,
            "auxk_indices": auxk_indices,
            "auxk_values": auxk_values,
            "latents_pre_act": latents_pre_act,
            "latents_post_act": latents,
        }

    def decode_sparse(self, indices, values):
        """Decode from sparse ``(index, value)`` pairs.

        Args:
            indices: Latent indices with shape ``(..., n_active)``. A 1-D tensor
                decodes a single vector, as before.
            values: Activations for those indices, with the same shape as
                ``indices``. They are cast to the decoder's dtype.

        Returns:
            A reconstruction of shape ``(..., d_model)``.
        """
        latents = torch.zeros(
            *indices.shape[:-1],
            self.n_dirs,
            device=indices.device,
            dtype=self.decoder.weight.dtype,
        )
        latents.scatter_(-1, indices, values.to(latents.dtype))
        return self.decoder(latents) + self.pre_bias

    def print_tensor_info(self, tensor, name):
        """Print the shape, dtype and device of ``tensor``, labelled with ``name``."""
        print(f"{name} - Shape: {tensor.shape}, Dtype: {tensor.dtype}, Device: {tensor.device}")

    def decode_clamp(self, latents, clamp, k=64):
        """Keep the ReLU-ed top-``k`` latents, scale them by ``clamp``, and decode.

        Args:
            latents: Dense latent activations; top-k is taken over the last dim.
            clamp: Elementwise multiplier. It is expected to be 1-D with one
                entry per latent and broadcasts across rows.
            k: Number of latents kept. It defaults to 64 (the previously
                hard-coded value) and is independent of ``self.k``.
        """
        sparse = self._sparsify_topk(latents, k)
        return self.decoder(sparse * clamp) + self.pre_bias

    def decode_at_k(self, latents, k):
        """Keep only the ReLU-ed top-``k`` latents along the last dim and decode them."""
        return self.decoder(self._sparsify_topk(latents, k)) + self.pre_bias

def unit_norm_decoder_(autoencoder: "FastAutoencoder") -> None:
    """Rescale each column of ``autoencoder.decoder.weight`` to unit L2 norm, in place.

    For ``nn.Linear(n_dirs, d_model)`` the weight has shape ``(d_model, n_dirs)``,
    so normalising over ``dim=0`` normalises each latent's decoder direction.
    All-zero columns stay zero instead of becoming NaN.
    """
    with torch.no_grad():
        weight = autoencoder.decoder.weight
        norms = weight.norm(dim=0, keepdim=True)
        # Clamp only guards the 0/0 case. Any norm >= the smallest normal float
        # is left unchanged, so regular columns are scaled exactly as before.
        weight.div_(norms.clamp_min(torch.finfo(weight.dtype).tiny))

def unit_norm_decoder_grad_adjustment_(autoencoder: FastAutoencoder) -> None:
    """Remove from the decoder gradient its component along each decoder column, in place.

    For every column ``w`` with gradient ``g``, this does ``g -= (w . g) * w``.
    The projection assumes the columns are unit-norm (e.g. after
    ``unit_norm_decoder_``). Does nothing when there is no gradient yet.
    """
    weight = autoencoder.decoder.weight
    grad = weight.grad
    if grad is None:
        return
    with torch.no_grad():
        parallel = (weight * grad).sum(dim=0, keepdim=True)
        grad.sub_(parallel * weight)

def mse(output, target):
    """Return the mean squared error between ``output`` and ``target``, averaged over all elements."""
    return F.mse_loss(input=output, target=target, reduction="mean")

def normalized_mse(recon, xs):
    """MSE of ``recon`` against ``xs``, divided by the MSE of predicting the batch mean.

    The baseline predicts the per-feature mean over dim 0 for every row.
    A value of 1.0 means "no better than the mean".
    """
    batch_mean = xs.mean(dim=0, keepdim=True).expand_as(xs)
    numerator = F.mse_loss(recon, xs)
    denominator = F.mse_loss(batch_mean, xs)
    return numerator / denominator

def loss_fn(ae, x, recons, info, auxk_coef, multik_coef):
    """Compute the training loss for one autoencoder forward pass.

    Args:
        ae: The autoencoder; ``ae.auxk`` and ``ae.decoder`` are read.
        x: The input batch that was passed to ``ae``.
        recons: The top-k reconstruction returned by ``ae(x)``.
        info: The info dict returned by ``ae(x)``.
        auxk_coef: Weight of the AuxK loss.
        multik_coef: Weight of the multi-k reconstruction loss.

    Returns:
        ``(total_loss, recons_loss, auxk_loss)``. ``recons_loss`` already
        includes ``multik_coef`` times the multi-k term. When AuxK is enabled,
        ``total_loss`` adds ``auxk_coef * auxk_loss``; otherwise
        ``auxk_loss`` is a zero scalar on ``x``'s device.
    """
    recons_loss = normalized_mse(recons, x) + multik_coef * normalized_mse(info["multik_recons"], x)

    if ae.auxk is None:
        auxk_loss = torch.tensor(0.0, device=x.device)
        return recons_loss, recons_loss, auxk_loss

    # The residual is computed against a detached reconstruction, so the AuxK
    # loss sends no gradient back through the main top-k path.
    e = x - recons.detach()
    auxk_latents = torch.zeros_like(info["latents_pre_act"])
    # scatter_ requires matching dtypes, so align the AuxK values with the latents.
    auxk_latents.scatter_(-1, info["auxk_indices"], info["auxk_values"].to(auxk_latents.dtype))
    # Reconstruct the residual from the dead latents only. No pre_bias is
    # added because the target is a difference.
    e_hat = ae.decoder(auxk_latents)
    auxk_loss = normalized_mse(e_hat, e)
    total_loss = recons_loss + auxk_coef * auxk_loss
    return total_loss, recons_loss, auxk_loss

def init_from_data_(ae, data_sample):
    """Initialise ``ae``'s parameters in place from a sample of training data.

    - ``pre_bias`` is set to the coordinate-wise median of ``data_sample``
      over dim 0. Per the ``torch.median`` docs, an even count yields the
      lower of the two middle values.
    - ``decoder.weight`` is Xavier-uniform initialised, then each column is
      rescaled to unit norm.
    - ``encoder.weight`` is set to the transpose of the decoder weight (they
      are tied only at initialisation).
    - ``latent_bias`` is set to zeros.

    Values are copied into the existing parameters, so the parameters keep
    their own device and dtype even if ``data_sample`` lives elsewhere.
    """
    with torch.no_grad():
        ae.pre_bias.copy_(torch.median(data_sample, dim=0).values)

    nn.init.xavier_uniform_(ae.decoder.weight)
    unit_norm_decoder_(ae)

    with torch.no_grad():
        ae.encoder.weight.copy_(ae.decoder.weight.t())

    nn.init.zeros_(ae.latent_bias)