from typing import Any, Optional, Union

from transformers import PretrainedConfig

class ScgptConfig(PretrainedConfig):
    """Configuration for an scGPT transformer model.

    Stores the hyperparameters needed to build the backbone and its task
    heads, and inherits JSON serialization (``save_pretrained`` /
    ``from_pretrained``) from ``PretrainedConfig``.
    """

    model_type = "scgpt"

    def __init__(
        self,
        ntoken: int,  # vocabulary size
        d_model: int,  # embedding dimension
        nhead: int,  # number of attention heads
        d_hid: int,  # feed-forward hidden dimension
        nlayers: int,  # number of transformer encoder layers
        nlayers_cls: int = 3,  # depth of the classification head
        n_cls: int = 1,  # number of output classes
        vocab: Any = None,  # gene vocabulary object
        dropout: float = 0.5,
        pad_token: str = "<pad>",
        pad_value: int = 0,  # padding value for expression inputs
        pert_pad_id: int = 2,  # padding id for perturbation tokens
        do_mvc: bool = False,  # masked-value prediction for cell embedding
        do_dab: bool = False,  # domain-adversarial batch correction
        use_batch_labels: bool = False,
        num_batch_labels: Optional[int] = None,
        domain_spec_batchnorm: Union[bool, str] = False,
        input_emb_style: str = "continuous",
        n_input_bins: Optional[int] = None,  # bin count for binned inputs
        cell_emb_style: str = "cls",  # how the cell embedding is pooled
        mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3,  # elastic cell similarity threshold
        explicit_zero_prob: bool = False,  # model zero expression explicitly
        use_fast_transformer: bool = False,
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
        use_mod: bool = False,  # optional second-modality vocabulary
        ntokens_mod: Optional[int] = None,
        vocab_mod: Optional[Any] = None,
        **kwargs,
    ):
        self.ntoken = ntoken
        self.d_model = d_model
        self.nhead = nhead
        self.d_hid = d_hid
        self.nlayers = nlayers
        self.nlayers_cls = nlayers_cls
        self.n_cls = n_cls
        self.vocab = vocab
        self.dropout = dropout
        self.pad_token = pad_token
        self.pad_value = pad_value
        self.pert_pad_id = pert_pad_id
        self.do_mvc = do_mvc
        self.do_dab = do_dab
        self.use_batch_labels = use_batch_labels
        self.num_batch_labels = num_batch_labels
        self.domain_spec_batchnorm = domain_spec_batchnorm
        self.input_emb_style = input_emb_style
        self.n_input_bins = n_input_bins
        self.cell_emb_style = cell_emb_style
        self.mvc_decoder_style = mvc_decoder_style
        self.ecs_threshold = ecs_threshold
        self.explicit_zero_prob = explicit_zero_prob
        self.use_fast_transformer = use_fast_transformer
        self.fast_transformer_backend = fast_transformer_backend
        self.pre_norm = pre_norm
        self.use_mod = use_mod
        self.ntokens_mod = ntokens_mod
        self.vocab_mod = vocab_mod
        super().__init__(**kwargs)
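

# Minimal usage sketch: the dimension values below are illustrative, not tied
# to any released scGPT checkpoint. It exercises the JSON round-trip that
# PretrainedConfig provides out of the box.
if __name__ == "__main__":
    config = ScgptConfig(
        ntoken=60697,  # illustrative vocabulary size
        d_model=512,
        nhead=8,
        d_hid=512,
        nlayers=12,
    )
    config.save_pretrained("scgpt-config")  # writes scgpt-config/config.json
    reloaded = ScgptConfig.from_pretrained("scgpt-config")
    assert reloaded.d_model == config.d_model
    print(reloaded.model_type)  # -> "scgpt"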