from typing import List, Optional

from transformers import PretrainedConfig


class FlowformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a `Flowformer`. It is used to instantiate a
    Flowformer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of our model for ALL data
    (https://arxiv.org/abs/2108.10072).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        dim_hidden (`int`, *optional*, defaults to 32):
            The dimensionality of the hidden states. `dim_hidden` must be divisible by `num_heads`, i.e.
            `dim_hidden % num_heads == 0`.
        num_heads (`int`, *optional*, defaults to 4):
            The number of attention heads.
        num_inds (`int`, *optional*, defaults to 16):
            The number of inducing points.
        hidden_layers (`int`, *optional*, defaults to 3):
            The number of hidden layers.
        layer_norm (`bool`, *optional*, defaults to `True`):
            Whether to apply layer normalization.
        dim_input (`int`, *optional*, defaults to 11):
            The dimensionality of the input, i.e. the number of markers.
        markers (`list`, *optional*, defaults to `["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]`):
            The list of markers; its length must equal `dim_input`.
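
    Example:

    A minimal usage sketch: instantiating the configuration with its defaults and checking that the marker panel
    matches `dim_input` (the import statement is omitted because the package path is not specified here):

    ```python
    >>> # Initializing a Flowformer configuration with the default (ALL-data) settings
    >>> configuration = FlowformerConfig()

    >>> # The marker panel and the input dimensionality always agree
    >>> len(configuration.markers) == configuration.dim_input
    True
    ```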
    """

    def __init__(
        self,
        dim_hidden: int = 32,
        num_heads: int = 4,
        num_inds: int = 16,
        hidden_layers: int = 3,
        layer_norm: bool = True,
        dim_input: int = 11,
        markers: Optional[List[str]] = None,
        **kwargs,
    ):
        # Avoid a mutable default argument: fall back to the default marker panel.
        if markers is None:
            markers = ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]
        assert dim_input == len(markers), "dim_input must be equal to the number of markers"
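        # Enforce the documented constraint: dim_hidden must be divisible by num_heads
        # so the hidden dimension can be split evenly across the attention heads.
        assert dim_hidden % num_heads == 0, "dim_hidden must be divisible by num_heads"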

        self.dim_hidden = dim_hidden
        self.num_heads = num_heads
        self.num_inds = num_inds
        self.hidden_layers = hidden_layers
        self.layer_norm = layer_norm
        self.dim_input = dim_input
        self.markers = markers

        super().__init__(**kwargs)