# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple


@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
        "GptNeoxMLP"
    )
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0
    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # settings for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768
    vision_adapter_dim: int = 512
    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(
                self.vocab_size, self.padding_multiple
            )
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(
                    f"The config {self.name!r} needs to set the `intermediate_size`"
                )
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        if name not in name_to_config:
            # search through all `config['hf_config']['name']`
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
                    == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml` and, if it doesn't exist, a matching config from `litgpt/config.py`."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(
            f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
        )

    @property
    def mlp_class(self) -> Type:
        # `self.mlp_class_name` cannot be the type to keep the config serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        # `self.norm_class_name` cannot be the type to keep the config serializable
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)


configs = []

name_to_config = {config["name"]: config for config in configs}
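
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module). The field
# values below are hypothetical; they merely demonstrate how `__post_init__`
# derives `head_size`, `padded_vocab_size`, `n_query_groups`, and
# `intermediate_size` when those fields are left unset, and how passing
# `n_query_groups` explicitly selects grouped-query attention.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo = Config(name="demo", n_layer=2, n_head=8, n_embd=512, vocab_size=32000)
    print(demo.head_size)          # 512 // 8 = 64
    print(demo.padded_vocab_size)  # find_multiple(32000, 512) = 32256
    print(demo.n_query_groups)     # defaults to n_head = 8, i.e. plain MHA
    print(demo.intermediate_size)  # 4 * n_embd = 2048 for the GptNeoxMLP default

    # 8 query heads sharing 2 key/value heads -> grouped-query attention (GQA)
    gqa = Config(name="demo-gqa", n_head=8, n_embd=512, n_query_groups=2)
    print(gqa.n_query_groups)  # 2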