surokpro2 commited on
Commit
d04c320
1 Parent(s): b30a1d3
Files changed (1) hide show
  1. SAE/sae.py +3 -1
SAE/sae.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn as nn
8
  import os
9
  import json
10
  import spaces
 
11
 
12
  class SparseAutoencoder(nn.Module):
13
  """
@@ -106,7 +107,6 @@ class SparseAutoencoder(nn.Module):
106
 
107
  @spaces.GPU
108
  def encode(self, x):
109
- self = self.to('cuda')
110
  x = x - self.pre_bias
111
  latents_pre_act = self.encoder(x) + self.latent_bias
112
 
@@ -122,6 +122,8 @@ class SparseAutoencoder(nn.Module):
122
  return latents
123
 
124
  def forward(self, x):
 
 
125
  x = x - self.pre_bias
126
  latents_pre_act = self.encoder(x) + self.latent_bias
127
  vals, inds = torch.topk(
 
8
  import os
9
  import json
10
  import spaces
11
+ import logging
12
 
13
  class SparseAutoencoder(nn.Module):
14
  """
 
107
 
108
  @spaces.GPU
109
  def encode(self, x):
 
110
  x = x - self.pre_bias
111
  latents_pre_act = self.encoder(x) + self.latent_bias
112
 
 
122
  return latents
123
 
124
  def forward(self, x):
125
+ print(x.device)
126
+ print(self.pre_bias.device)
127
  x = x - self.pre_bias
128
  latents_pre_act = self.encoder(x) + self.latent_bias
129
  vals, inds = torch.topk(