import math
from typing import Optional, Union

from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """Configuration for a Mamba model, following the Hugging Face
    `PretrainedConfig` conventions.

    The derived field `d_inner` is the expanded inner dimension
    (`expand * d_model`); `dt_rank` defaults to `ceil(d_model / 16)`
    when set to "auto"; and `vocab_size` is rounded up to a multiple
    of `pad_vocab_size_multiple`.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        # Expanded inner dimension used by the Mamba blocks.
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = dt_rank
        self.initializer_range = initializer_range
        self.bias = bias

        # "auto" resolves the dt (Δ) projection rank from the model width.
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round the vocabulary size up to the next multiple of
        # pad_vocab_size_multiple.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )

        super().__init__(**kwargs)
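

# Quick sanity-check sketch of the derived defaults: with the arguments above,
# d_inner = expand * d_model = 5120, dt_rank = ceil(d_model / 16) = 160, and
# vocab_size is padded from 50277 up to 50280 (the next multiple of
# pad_vocab_size_multiple = 8).
if __name__ == "__main__":
    config = MambaConfig()
    print(config.d_inner)     # 5120
    print(config.dt_rank)     # 160
    print(config.vocab_size)  # 50280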