File size: 7,645 Bytes
2776201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union
import torch
import yaml
from typing_extensions import Self
import litgpt.model
from litgpt.utils import find_multiple
@dataclass
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_embeddings: bool = False
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
# Example with `n_head=4`
# ββββββββββββββββββββ βββββ βββββ βββββ
# β v ββ v ββ v ββ v β β v β β v β β v β
# ββββββββββββββββββββ βββββ βββββ βββββ
# β β β β β β β
# ββββββββββββββββββββ βββββ βββββ βββββ
# β k ββ k ββ k ββ k β β k β β k β β k β
# ββββββββββββββββββββ βββββ βββββ βββββ
# β β β β ββββ΄βββ ββββ΄βββ ββββββ¬βββ΄ββ¬βββββ
# ββββββββββββββββββββ ββββββββββββββββββββ ββββββββββββββββββββ
# β q ββ q ββ q ββ q β β q ββ q ββ q ββ q β β q ββ q ββ q ββ q β
# ββββββββββββββββββββ ββββββββββββββββββββ ββββββββββββββββββββ
# ββββββββββββββββββββΆ ββββββββββββββββββββΆ ββββββββββββββββββββΆ
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
"GptNeoxMLP"
)
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
n_expert: int = 0
n_expert_per_token: int = 0
add_qkv_bias: Optional[bool] = None
prompt_vocab_size: Optional[int] = None
attn_dropout: float = 0.0
pos_type: str = "rope"
force_align: bool = False
use_pretrain_phoneme_emb: bool = False
tie_word_embeddings: bool = False
# setting for mini-omni
text_vocab_size:int = 152000
cat_audio_vocab_size: int = 29120
audio_vocab_size: int = 4160
whisper_adapter_dim: int = 768
post_adapter: bool = False
post_adapter_layers: int = 6
asr_adapter: str = "llamamlp"
def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(
self.vocab_size, self.padding_multiple
)
else:
# vocab size shouldn't be larger than padded vocab size
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
# compute the number of query groups
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
# compute the intermediate size for MLP if not set
if self.intermediate_size is None:
if self.mlp_class_name == "LLaMAMLP":
raise ValueError(
f"The config {self.name!r}, needs to set the `intermediate_size`"
)
self.intermediate_size = 4 * self.n_embd
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
if self.add_qkv_bias is None:
self.add_qkv_bias = self.bias
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
if name not in name_to_config:
# search through all `config['hf_config']['name']`
try:
conf_dict = next(
config
for config in configs
if name == config["hf_config"]["name"]
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
== name
)
except StopIteration:
raise ValueError(f"{name!r} is not a supported config name")
else:
conf_dict = name_to_config[name]
conf_dict = conf_dict.copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@classmethod
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
file_kwargs = yaml.safe_load(fp)
if file_kwargs is None:
raise ValueError(f"{path} is empty which is likely unexpected.")
file_kwargs.update(kwargs)
return cls(**file_kwargs)
@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "model_config.yaml").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
)
@property
def mlp_class(self) -> Type:
# `self.mlp_class_name` cannot be the type to keep the config serializable
return getattr(litgpt.model, self.mlp_class_name)
@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config serializable
if self.norm_class_name == "RMSNorm":
from functools import partial
from litgpt.model import RMSNorm
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
return getattr(torch.nn, self.norm_class_name)
configs = []
name_to_config = {config["name"]: config for config in configs}
|