agemagician commited on
Commit
c977434
1 Parent(s): 1fc1b57

Create configuration_scgpt.py

Browse files
Files changed (1) hide show
  1. configuration_scgpt.py +75 -0
configuration_scgpt.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ class ScgptConfig(PretrainedConfig):
6
+ model_type = "scgpt"
7
+
8
+ def __init__(
9
+ self,
10
+ ntoken: int,
11
+ d_model: int,
12
+ nhead: int,
13
+ d_hid: int,
14
+ nlayers: int,
15
+ nlayers_cls: int = 3,
16
+ n_cls: int = 1,
17
+ vocab: Any = None,
18
+ dropout: float = 0.5,
19
+ pad_token: str = "<pad>",
20
+ pad_value: int = 0,
21
+ pert_pad_id: int = 2,
22
+ do_mvc: bool = False,
23
+ do_dab: bool = False,
24
+ use_batch_labels: bool = False,
25
+ num_batch_labels: Optional[int] = None,
26
+ domain_spec_batchnorm: Union[bool, str] = False,
27
+ input_emb_style: str = "continuous",
28
+ n_input_bins: Optional[int] = None,
29
+ cell_emb_style: str = "cls",
30
+ mvc_decoder_style: str = "inner product",
31
+ ecs_threshold: float = 0.3,
32
+ explicit_zero_prob: bool = False,
33
+ use_fast_transformer: bool = False,
34
+ fast_transformer_backend: str = "flash",
35
+ pre_norm: bool = False,
36
+ use_mod: bool = False,
37
+ ntokens_mod: Optional[int] = None,
38
+ vocab_mod: Optional[Any] = None,
39
+ **kwargs,
40
+ ):
41
+ #if block_type not in ["basic", "bottleneck"]:
42
+ # raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
43
+ #if stem_type not in ["", "deep", "deep-tiered"]:
44
+ # raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
45
+
46
+ self.ntoken = ntoken
47
+ self.d_model = d_model
48
+ self.nhead = nhead
49
+ self.d_hid = d_hid
50
+ self.nlayers = nlayers
51
+ self.nlayers_cls = nlayers_cls
52
+ self.n_cls = n_cls
53
+ self.vocab = vocab
54
+ self.dropout = dropout
55
+ self.pad_token = pad_token
56
+ self.pad_value = pad_value
57
+ self.pert_pad_id = pert_pad_id
58
+ self.do_mvc = do_mvc
59
+ self.do_dab = do_dab
60
+ self.use_batch_labels = use_batch_labels
61
+ self.num_batch_labels = num_batch_labels
62
+ self.domain_spec_batchnorm = domain_spec_batchnorm
63
+ self.input_emb_style = input_emb_style
64
+ self.n_input_bins = n_input_bins
65
+ self.cell_emb_style = cell_emb_style
66
+ self.mvc_decoder_style = mvc_decoder_style
67
+ self.ecs_threshold = ecs_threshold
68
+ self.explicit_zero_prob = explicit_zero_prob
69
+ self.use_fast_transformer = use_fast_transformer
70
+ self.fast_transformer_backend = fast_transformer_backend
71
+ self.pre_norm = pre_norm
72
+ self.use_mod = use_mod
73
+ self.ntokens_mod = ntokens_mod
74
+ self.vocab_mod = vocab_mod
75
+ super().__init__(**kwargs)