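# Model configuration dataclass in the style of litgpt's `litgpt/config.py`,
# extended here with audio/ASR-related fields.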
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple


@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
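
    # number of key/value groups for attention: `n_head` means standard
    # multi-head attention (the default, set in __post_init__), 1 means
    # multi-query attention, and values in between mean grouped-query attention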
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0

    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False
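
    # multimodal vocabulary sizes: text tokens, the concatenated audio
    # codebooks (7 * audio_vocab_size with these defaults), a single audio
    # codebook, and the hidden size of the Whisper audio adapter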
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            # default head size: split the embedding dimension evenly across heads
            self.head_size = self.n_embd // self.n_head

        # pad the vocabulary to a multiple of `padding_multiple` for efficiency
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # grouped-query attention: the heads must divide evenly into the groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # derive the MLP intermediate size if it was not given explicitly
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(f"The config {self.name!r} needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

        # number of embedding dimensions that rotary position embeddings cover
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        if name not in name_to_config:
            # fall back to matching the Hugging Face name or "org/name"
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        # copy so the kwargs cannot overwrite the registered config dict
        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml`, or fall back to a matching config from `litgpt/config.py` if it doesn't exist."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.")
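
    # the MLP and norm classes are stored by name and resolved at access time,
    # keeping the config itself plain data that round-trips through YAML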
    @property
    def mlp_class(self) -> Type:
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            # Gemma's RMSNorm variant scales by (1 + weight) instead of weight
            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)


configs = []
name_to_config = {config["name"]: config for config in configs}
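
# Usage sketch (hypothetical names): once `configs` is populated with model
# entries, a config can be built from a registered name or loaded from a
# checkpoint directory, e.g.
#   cfg = Config.from_name("some-registered-model", block_size=2048)
#   cfg = Config.from_checkpoint(Path("checkpoints/some-model-dir"))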