File size: 926 Bytes
40d1a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import List, Optional

from transformers import PretrainedConfig


class MoonshineConfig(PretrainedConfig):
    """Configuration object holding the Moonshine model hyper-parameters.

    Every hyper-parameter is stored as an attribute of the same name. Any
    remaining keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        dim: Model width.
        inner_dim: Inner width that must divide evenly by ``n_head``
            (presumably the attention width split across heads). Defaults
            to ``dim`` when ``None``.
        enc_depth: Encoder depth (presumably the number of encoder layers).
        dec_depth: Decoder depth (presumably the number of decoder layers).
        n_head: Number of heads. Must be a positive integer.
        dec_voc_size: Decoder vocabulary size.
        enc_ff_swiglu: Encoder feed-forward flag (by its name, selects a
            SwiGLU feed-forward; confirm against the model code).
        dec_ff_swiglu: The same flag for the decoder.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is
            not divisible by ``n_head``.
    """

    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs,
    ):
        if inner_dim is None:
            inner_dim = dim
        # Validate n_head before the modulo below. Otherwise 0 would raise
        # ZeroDivisionError and a negative value would slip through.
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                f"`inner_dim` ({inner_dim}) must be divisible by `n_head` ({n_head})"
            )
        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu
        # The base class is initialised last, after the subclass
        # attributes have been set.
        super().__init__(**kwargs)