from typing import List, Optional

from transformers import PretrainedConfig


class MoonshineConfig(PretrainedConfig):
    """Configuration object for the Moonshine model.

    Stores the hyper-parameters below as attributes and forwards every
    other keyword argument to ``PretrainedConfig.__init__``.

    Args:
        dim: Model width. Also used as ``inner_dim`` when that is not given.
        inner_dim: Inner width; must be divisible by ``n_head``. Defaults to
            ``dim`` when ``None``.
        enc_depth: Encoder depth (presumably the number of encoder layers).
        dec_depth: Decoder depth (presumably the number of decoder layers).
        n_head: Number of heads; must be a positive integer that divides
            ``inner_dim``.
        dec_voc_size: Decoder vocabulary size.
        enc_ff_swiglu: Flag stored for the encoder feed-forward (SwiGLU
            variant, judging by the name).
        dec_ff_swiglu: Flag stored for the decoder feed-forward (SwiGLU
            variant, judging by the name).
        **kwargs: Passed through unchanged to ``PretrainedConfig``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is
            not divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs,
    ):
        if inner_dim is None:
            inner_dim = dim
        # Check n_head before using it as a modulus: 0 would otherwise raise
        # ZeroDivisionError and a negative value would slip through.
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                f"`inner_dim` ({inner_dim}) must be divisible by `n_head` ({n_head})"
            )

        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu
        super().__init__(**kwargs)