# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple

@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
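    # illustrative note (not from the diagram above): with this file's default
    # n_head=32, n_query_groups=32 means MHA, n_query_groups=8 means GQA with
    # 4 query heads sharing each K/V head, and n_query_groups=1 means MQA;
    # `__post_init__` asserts that n_head is divisible by n_query_groups.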
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
        "GptNeoxMLP"
    )
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0
    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # settings for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768
    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # pad the vocab size to a multiple of `padding_multiple` to be optimal on hardware
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(f"The config {self.name!r} needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias
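    # Worked example with the defaults above (illustrative): n_embd=4096 / n_head=32
    # gives head_size=128; find_multiple(50254, 512) pads the vocab to 50688;
    # rope_n_elem = int(0.25 * 128) = 32; add_qkv_bias falls back to `bias` (True).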

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        if name not in name_to_config:
            # search through all `config['hf_config']['name']`
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)
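    # Usage sketch (hypothetical model name; it must exist in `name_to_config`
    # or match an `hf_config` "org/name" entry in `configs`):
    #   config = Config.from_name("some-registered-model", block_size=2048)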

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)
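    # Usage sketch (hypothetical path; the YAML keys are simply fields of this
    # dataclass, e.g. `name`, `n_layer`, `n_head`, `n_embd`):
    #   config = Config.from_file("checkpoints/my-model/model_config.yaml", block_size=1024)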

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml` and, if it doesn't exist, a matching config from `litgpt/config.py`."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.")
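    # Resolution sketch (hypothetical directory "checkpoints/my-model"): a
    # `model_config.yaml` inside the directory takes precedence; otherwise the
    # directory name "my-model" is looked up in `name_to_config`; otherwise a
    # FileNotFoundError is raised.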

    @property
    def mlp_class(self) -> Type:
        # `self.mlp_class_name` cannot be the type to keep the config serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        # `self.norm_class_name` cannot be the type to keep the config serializable
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)
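    # Illustrative: a config whose `name` contains "Gemma" gets RMSNorm with the
    # unit offset (Gemma scales activations by 1 + weight); any other RMSNorm
    # config uses the plain litgpt RMSNorm, and "LayerNorm" resolves to
    # torch.nn.LayerNorm.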


configs = []

name_to_config = {config["name"]: config for config in configs}
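

# Minimal sanity-check sketch (illustrative, assumes the surrounding `litgpt`
# package is importable): builds a Config directly and prints the fields derived
# in `__post_init__`, using the dataclass defaults plus a GQA-style override.
if __name__ == "__main__":
    demo = Config(name="demo", n_head=32, n_query_groups=8)  # 4 query heads per K/V head
    print(demo.head_size)          # 4096 // 32 = 128
    print(demo.padded_vocab_size)  # find_multiple(50254, 512) = 50688
    print(demo.intermediate_size)  # 4 * 4096 = 16384 (GptNeoxMLP default)
    print(demo.rope_n_elem)        # int(0.25 * 128) = 32
    print(demo.norm_class)         # torch.nn.LayerNorm (norm_class_name default)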