|
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
|
|
class MoonshineConfig(PretrainedConfig):
    """Configuration container for a Moonshine model.

    Stores the constructor arguments as instance attributes and forwards any
    remaining keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        dim: Stored as ``self.dim``; also the fallback value for
            ``inner_dim``. Presumably the model width -- confirm against the
            modeling code.
        inner_dim: Stored as ``self.inner_dim``. When ``None`` it defaults
            to ``dim``. Must be divisible by ``n_head``.
        enc_depth: Stored as ``self.enc_depth`` (presumably the number of
            encoder layers -- confirm against the modeling code).
        dec_depth: Stored as ``self.dec_depth`` (presumably the number of
            decoder layers -- confirm against the modeling code).
        n_head: Stored as ``self.n_head``. Must be a positive integer that
            divides ``inner_dim``.
        dec_voc_size: Stored as ``self.dec_voc_size`` (presumably the decoder
            vocabulary size -- confirm against the modeling code).
        enc_ff_swiglu: Stored as ``self.enc_ff_swiglu`` (presumably selects a
            SwiGLU feed-forward in the encoder -- confirm).
        dec_ff_swiglu: Stored as ``self.dec_ff_swiglu`` (presumably selects a
            SwiGLU feed-forward in the decoder -- confirm).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is not
            divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs
    ):
        if inner_dim is None:
            inner_dim = dim

        # Reject a zero/negative head count up front so the caller gets a
        # descriptive ValueError instead of a ZeroDivisionError from the
        # modulo check below.
        if n_head <= 0:
            raise ValueError("`n_head` must be a positive integer")
        if inner_dim % n_head != 0:
            raise ValueError("`inner_dim` must be divisible by `n_head`")

        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu

        # Base-class initialisation runs last, as in the original, so it
        # receives only the keyword arguments not consumed above.
        super().__init__(**kwargs)
|
|