File size: 926 Bytes
40d1a51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import List, Optional

from transformers import PretrainedConfig
class MoonshineConfig(PretrainedConfig):
    """Configuration for the ``moonshine`` model type.

    Every hyper-parameter below is stored on the instance as an attribute of
    the same name. Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        dim: Base model width. Presumably the hidden size of the encoder and
            decoder; confirm against the model implementation.
        inner_dim: Secondary width; defaults to ``dim`` when ``None``. It must
            be divisible by ``n_head``, which suggests it is split evenly
            across attention heads (confirm against the model code).
        enc_depth: Stored as ``enc_depth``; presumably the number of encoder
            layers.
        dec_depth: Stored as ``dec_depth``; presumably the number of decoder
            layers.
        n_head: Number of heads. Must be a positive integer that divides
            ``inner_dim``.
        dec_voc_size: Stored as ``dec_voc_size``; presumably the decoder
            vocabulary size.
        enc_ff_swiglu: Flag stored as ``enc_ff_swiglu``; presumably selects a
            SwiGLU feed-forward block in the encoder.
        dec_ff_swiglu: Flag stored as ``dec_ff_swiglu``; presumably selects a
            SwiGLU feed-forward block in the decoder.
        **kwargs: Passed through to ``PretrainedConfig``.

    Raises:
        ValueError: If ``n_head`` is not positive, or ``inner_dim`` (after
            defaulting) is not divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs,
    ) -> None:
        if inner_dim is None:
            inner_dim = dim
        # Validate before the modulo below: n_head == 0 would otherwise surface
        # as an opaque ZeroDivisionError, and a negative head count could slip
        # through whenever it happens to divide inner_dim.
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                f"`inner_dim` ({inner_dim}) must be divisible by `n_head` ({n_head})"
            )
        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu
        # Model-specific attributes are set first, then the base class consumes
        # the remaining generic options, matching the original ordering.
        super().__init__(**kwargs)
|