# pip install transformers
from transformers import PretrainedConfig
from typing import List


'''
Reference training configuration (kept as documentation; not consumed by
FSAEConfig directly):

network_config = {
    "epochs": 150,
    "batch_size": 250,
    "n_steps": 16,  # number of simulation timesteps
    "dataset": "CAPS",
    "in_channels": 1,
    "data_path": "./data",
    "lr": 0.001,
    "n_class": 10,
    "latent_dim": 128,
    "input_size": 32,
    "model": "FSVAE",  # FSVAE or FSVAE_large
    "k": 20,  # channel multiplier
    "scheduled": True,  # whether to apply scheduled sampling
    "loss_func": "kld",  # "mmd" or "kld"
    "accum_iter": 1,
    "devices": [0],
}

hidden_dims = [32, 64, 128, 256]
'''
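
# A hypothetical sketch of how the architecture-related keys from the legacy
# dict above would map onto FSAEConfig; training hyperparameters such as
# "epochs", "lr", and "batch_size" stay with the training script. The name
# `network_config` refers to the dict in the comment and is not defined here:
#
#     config = FSAEConfig(
#         in_channels=network_config["in_channels"],
#         n_steps=network_config["n_steps"],
#         latent_dim=network_config["latent_dim"],
#         k=network_config["k"],
#         scheduled=network_config["scheduled"],
#     )
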

class FSAEConfig(PretrainedConfig):
    model_type = "fsae"

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: List[int] = [32, 64, 128, 256],
        k: int = 20,  # channel multiplier
        n_steps: int = 16,  # number of simulation timesteps
        latent_dim: int = 128,
        scheduled: bool = True,  # whether to apply scheduled sampling
        # loss_func: str = "kld",  # "mmd" or "kld"
        # SNN neuron constants consumed by the model implementation:
        dt: float = 5.0,
        a: float = 0.25,
        aa: float = 0.5,
        Vth: float = 0.2,  # threshold potential
        tau: float = 0.25,  # membrane decay constant
        **kwargs,
    ):
        self.in_channels = in_channels
        self.hidden_dims = hidden_dims
        self.k = k
        self.n_steps = n_steps
        self.latent_dim = latent_dim
        self.scheduled = scheduled
        self.dt = dt
        self.a = a
        self.aa = aa
        self.Vth = Vth
        self.tau = tau
        super().__init__(**kwargs)
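

# A minimal usage sketch, assuming only the standard `transformers`
# custom-config workflow (save_pretrained / from_pretrained); the directory
# name "fsae-base" is illustrative:
if __name__ == "__main__":
    config = FSAEConfig(latent_dim=128, n_steps=16)
    config.save_pretrained("fsae-base")  # writes fsae-base/config.json
    reloaded = FSAEConfig.from_pretrained("fsae-base")
    assert reloaded.latent_dim == config.latent_dim

    # To use the config with the Auto classes, it can be registered under its
    # model_type (assuming no other config already claims "fsae"):
    #     from transformers import AutoConfig
    #     AutoConfig.register("fsae", FSAEConfig)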