# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple


@dataclass
class Config:
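    """Model hyperparameters for a litgpt-style GPT, extended with mini-omni audio/text settings.

    Fields left as `None` (e.g. `head_size`, `padded_vocab_size`, `n_query_groups`,
    `intermediate_size`) are derived in `__post_init__`.
    """
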
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”     β”Œβ”€β”€β”€β”    β”Œβ”€β”€β”€β”             β”Œβ”€β”€β”€β”
    # β”‚ v β”‚β”‚ v β”‚β”‚ v β”‚β”‚ v β”‚     β”‚ v β”‚    β”‚ v β”‚             β”‚ v β”‚
    # β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜     β””β”€β”€β”€β”˜    β””β”€β”€β”€β”˜             β””β”€β”€β”€β”˜
    #   β”‚    β”‚    β”‚    β”‚         β”‚        β”‚                 β”‚
    # β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”     β”Œβ”€β”€β”€β”    β”Œβ”€β”€β”€β”             β”Œβ”€β”€β”€β”
    # β”‚ k β”‚β”‚ k β”‚β”‚ k β”‚β”‚ k β”‚     β”‚ k β”‚    β”‚ k β”‚             β”‚ k β”‚
    # β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜     β””β”€β”€β”€β”˜    β””β”€β”€β”€β”˜             β””β”€β”€β”€β”˜
    #   β”‚    β”‚    β”‚    β”‚      β”Œβ”€β”€β”΄β”€β”€β”  β”Œβ”€β”€β”΄β”€β”€β”      β”Œβ”€β”€β”€β”€β”¬β”€β”€β”΄β”€β”¬β”€β”€β”€β”€β”
    # β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”  β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”  β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”β”Œβ”€β”€β”€β”
    # β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚  β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚  β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚β”‚ q β”‚
    # β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜  β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜  β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜β””β”€β”€β”€β”˜
    # ◀──────────────────▢  ◀──────────────────▢  ◀──────────────────▢
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
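    # e.g. (illustrative) with n_head=32: n_query_groups=32 gives MHA, n_query_groups=8 gives GQA, n_query_groups=1 gives MQA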
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
        "GptNeoxMLP"
    )
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0

    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # settings for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # pad the vocab size up to the nearest multiple of `padding_multiple` for better hardware utilization
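        # e.g. the default vocab_size=50254 with padding_multiple=512 is padded up to 50688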
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(
                self.vocab_size, self.padding_multiple
            )
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(
                    f"The config {self.name!r}, needs to set the `intermediate_size`"
                )
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
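        # e.g. with the defaults rotary_percentage=0.25, n_embd=4096, n_head=32: head_size=128 and rope_n_elem=32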

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
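        """Create a `Config` from a supported model name.

        The name is matched against `name_to_config` and, failing that, against each
        config's Hugging Face `name` or `org/name`. Keyword arguments override the
        matched values.
        """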
        if name not in name_to_config:
            # search through all `config['hf_config']['name']`
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
                    == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
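        """Load a `Config` from a YAML file; keyword arguments override values read from the file."""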
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(
            f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
        )

    @property
    def mlp_class(self) -> Type:
        # `self.mlp_class_name` cannot be the type to keep the config serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        # `self.norm_class_name` cannot be the type to keep the config serializable
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

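            # Gemma checkpoints use an RMSNorm variant with a unit offset, i.e. scaling by (1 + weight) instead of weight.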
            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)


configs = []
name_to_config = {config["name"]: config for config in configs}
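
# Example usage (illustrative; the checkpoint path and overrides below are hypothetical):
#   from pathlib import Path
#   config = Config.from_checkpoint(Path("checkpoints/mini-omni"))  # prefers model_config.yaml if present
#   config = Config(n_layer=24, n_head=16, n_embd=2048, intermediate_size=8192, mlp_class_name="LLaMAMLP")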