Spaces:
Sleeping
Sleeping
File size: 1,455 Bytes
8cd00a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os
from dataclasses import dataclass, field

import torch
@dataclass
class SAETrainingConfig:
    """Hyper-parameters for training one SAE, plus derived name/save-path helpers.

    Only ``sae_name`` and ``save_path`` are computed here; how the remaining
    fields are consumed lives in training code not shown in this file, so the
    per-field notes below are hedged where they go beyond the field name.
    """

    d_model: int  # presumably the width of the activations fed to the SAE — confirm
    n_dirs: int  # number of SAE directions; rendered as "hidden" in sae_name
    k: int  # rendered as "k" in sae_name; presumably the top-k sparsity level — confirm
    block_name: str  # prefix of sae_name (and therefore of save_path)
    bs: int  # presumably the batch size; rendered as "bs" in sae_name
    save_path_base: str  # base directory under which save_path is placed
    auxk: int = 256  # presumably the auxiliary top-k used for dead latents — confirm
    lr: float = 1e-4  # presumably the learning rate; rendered as "lr" in sae_name
    eps: float = 6.25e-10  # NOTE(review): consumer not visible here; likely an optimizer epsilon — confirm
    dead_toks_threshold: int = 10_000_000  # presumably tokens without activation before a latent counts as dead — confirm
    auxk_coef: float = 1 / 32  # presumably the weight of the auxiliary loss term — confirm

    @property
    def sae_name(self) -> str:
        """Return an identifier encoding the block name and key hyper-parameters."""
        return (
            f'{self.block_name}_k{self.k}_hidden{self.n_dirs}'
            f'_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
        )

    @property
    def save_path(self) -> str:
        """Return the save location for this SAE: ``save_path_base`` joined with ``sae_name``."""
        # Honour the configured base directory instead of a hard-coded,
        # machine-specific scratch path; reuse sae_name so the two can't drift.
        return os.path.join(self.save_path_base, self.sae_name)
@dataclass
class Config:
    """Top-level configuration for a run that trains several SAEs.

    Built from a parsed JSON-like mapping. The shared ``block_name``, ``bs``
    and ``save_path_base`` values are injected into every per-SAE entry of
    ``cfg_json['sae_configs']``.
    """

    saes: list[SAETrainingConfig]
    paths_to_latents: list[str]
    log_interval: int
    save_interval: int
    bs: int
    block_name: str
    # Declared as a field (it is set in __init__) so it appears in the
    # dataclass-generated repr() and == like the other settings.
    save_path_base: str
    wandb_project: str = 'sdxl_sae_train'
    wandb_name: str = 'multiple_sae'

    # @dataclass keeps an explicitly defined __init__, so this hand-written
    # constructor is the one that runs.
    def __init__(self, cfg_json):
        """Populate the config from a parsed JSON mapping.

        Args:
            cfg_json: mapping with required keys ``'sae_configs'`` (a list of
                mappings of ``SAETrainingConfig`` keyword arguments, excluding
                ``block_name``/``bs``/``save_path_base``), ``'save_path_base'``,
                ``'paths_to_latents'``, ``'log_interval'``, ``'save_interval'``,
                ``'bs'`` and ``'block_name'``. Optional ``'wandb_project'`` and
                ``'wandb_name'`` override the class defaults.

        Raises:
            KeyError: if a required key is missing.
            TypeError: if an entry of ``sae_configs`` also supplies
                ``block_name``, ``bs`` or ``save_path_base`` (duplicate keyword),
                or lacks a required ``SAETrainingConfig`` field.
        """
        block_name = cfg_json['block_name']
        bs = cfg_json['bs']
        save_path_base = cfg_json['save_path_base']

        # Every SAE shares the run-level block, batch size and save location.
        self.saes = [
            SAETrainingConfig(**sae_cfg, block_name=block_name, bs=bs, save_path_base=save_path_base)
            for sae_cfg in cfg_json['sae_configs']
        ]
        self.save_path_base = save_path_base
        self.paths_to_latents = cfg_json['paths_to_latents']
        self.log_interval = cfg_json['log_interval']
        self.save_interval = cfg_json['save_interval']
        self.bs = bs
        self.block_name = block_name
        # Fall back to the class-level defaults when the keys are absent, so
        # existing config files behave exactly as before.
        self.wandb_project = cfg_json.get('wandb_project', type(self).wandb_project)
        self.wandb_name = cfg_json.get('wandb_name', type(self).wandb_name)