File size: 1,733 Bytes
ba54498 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
# pip install transformers
from typing import List, Optional

from transformers import PretrainedConfig
'''
network_config = {
"epochs": 150,
"batch_size": 250,
"n_steps": 16, # timestep
"dataset": "CAPS",
"in_channels": 1,
"data_path": "./data",
"lr": 0.001,
"n_class": 10,
"latent_dim": 128,
"input_size": 32,
"model": "FSVAE" ,# FSVAE or FSVAE_large
"k": 20, # multiplier of channel
"scheduled": True, # whether to apply scheduled sampling
"loss_func": 'kld', # mmd or kld
"accum_iter" : 1,
"devices": [0],
}
hidden_dims = [32, 64, 128, 256]
'''
class FSAEConfig(PretrainedConfig):
    """Hugging Face configuration for the ``fsae`` model type.

    Every constructor argument below is stored as an instance attribute of the
    same name. Any additional keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        in_channels: Number of input channels.
        hidden_dims: Widths of the hidden layers. Defaults to
            ``[32, 64, 128, 256]``. The sequence is copied into a new list, so
            each config owns its own ``hidden_dims``. Later mutation of the
            caller's list, or of another config's list, does not affect this
            one.
        k: Channel multiplier.
        n_steps: Number of timesteps.
        latent_dim: Dimensionality of the latent representation.
        scheduled: Whether to apply scheduled sampling.
        dt: Model constant. NOTE(review): its meaning (presumably a simulation
            time step) is not visible in this file; confirm against the model.
        a: Model constant. NOTE(review): meaning not visible here; confirm.
        aa: Model constant. NOTE(review): meaning not visible here; confirm.
        Vth: Threshold potential.
        tau: Model constant. NOTE(review): presumably a time constant; confirm.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "fsae"

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: Optional[List[int]] = None,
        k: int = 20,
        n_steps: int = 16,
        latent_dim: int = 128,
        scheduled: bool = True,
        dt: float = 5,
        a: float = 0.25,
        aa: float = 0.5,
        Vth: float = 0.2,
        tau: float = 0.25,
        **kwargs,
    ):
        self.in_channels = in_channels
        # A list literal used as a default is created once, at function
        # definition time, and shared by every call. Build a fresh list per
        # instance instead, and copy caller-supplied sequences, so that
        # mutating one config's hidden_dims cannot leak into another.
        self.hidden_dims = (
            [32, 64, 128, 256] if hidden_dims is None else list(hidden_dims)
        )
        self.k = k
        self.n_steps = n_steps
        self.latent_dim = latent_dim
        self.scheduled = scheduled
        self.dt = dt
        self.a = a
        self.aa = aa
        self.Vth = Vth
        self.tau = tau
        super().__init__(**kwargs)