File size: 12,266 Bytes
5a393d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Phi-MoE model."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)

# Checkpoint name -> URL of that checkpoint's hosted `config.json`.
PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/Phi-3.5-MoE-instruct": (
        "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json"
    ),
}
class PhiMoEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32064):
            Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`PhiMoEModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6400):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If explicitly set to `None`, it falls back to
            `num_attention_heads` (i.e. MHA).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
            contain exactly the following six keys: `type`, `short_factor`, `long_factor`, `short_mscale`,
            `long_mscale` and `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale`
            and `long_mscale` must be numbers, the `short_factor` and `long_factor` must be lists of numbers with the
            same length as half of the attention head size (`hidden_size // num_attention_heads // 2`) and the
            `original_max_position_embeddings` must be an integer.
        sliding_window (`int`, *optional*):
            Sliding window attention window size. When not specified the config stores `None` (no window size is
            substituted here); how `None` is interpreted is up to the modeling code.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        num_local_experts (`int`, *optional*, defaults to 16):
            Number of experts per Sparse MLP layer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        router_jitter_noise (`float`, *optional*, defaults to 0.01):
            Amount of noise to add to the router.
        input_jitter_noise (`float`, *optional*, defaults to 0.0):
            Amount of input jitter noise. Stored on the config for the modeling code to apply.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the attention projection layers use a bias term (flag consumed by the modeling code).
        lm_head_bias (`bool`, *optional*, defaults to `False`):
            Whether the language-modeling head uses a bias term (flag consumed by the modeling code).

    ```python
    >>> from transformers import PhiMoEModel, PhiMoEConfig

    >>> # Loading the Phi-3.5-MoE configuration
    >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

    >>> # Initializing a model from the configuration
    >>> model = PhiMoEModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "phimoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32064,
        hidden_size=4096,
        intermediate_size=6400,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        rope_scaling=None,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=16,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.01,
        input_jitter_noise=0.0,
        attention_bias=False,
        lm_head_bias=False,
        **kwargs,
    ):
        """Store the architecture hyper-parameters and validate `rope_scaling`.

        Raises:
            ValueError: If `rope_scaling` is given but malformed (see `_rope_scaling_validation`).
        """
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias

        # For backward compatibility: an explicit `None` means one KV head per attention head (MHA).
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        self.input_jitter_noise = input_jitter_noise

        # Validation reads `hidden_size` and `num_attention_heads`, so it must run after they are set above.
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        # Token ids, embedding tying and any extra kwargs are handled by `PretrainedConfig`.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Does nothing when `rope_scaling` is `None`. Otherwise requires a dict with exactly six entries: a `type`
        equal to `"longrope"`, `short_factor` / `long_factor` lists of numbers whose length is
        `hidden_size // num_attention_heads // 2`, numeric `short_mscale` / `long_mscale`, and an integer
        `original_max_position_embeddings`.

        Raises:
            ValueError: On the first check that fails, in the order listed above.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
            raise ValueError(
                "`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, "
                f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
            )

        rope_scaling_type = self.rope_scaling.get("type", None)
        # A missing `type` (None) is also rejected here, since None is not in the allowed tuple.
        if rope_scaling_type not in ("longrope",):
            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")

        for key in ("short_factor", "long_factor"):
            factors = self.rope_scaling.get(key, None)
            if not (isinstance(factors, list) and all(isinstance(x, (int, float)) for x in factors)):
                raise ValueError(f"`rope_scaling`'s {key} field must be a list of numbers, got {factors}")
            # Computed only after the type check so that error keeps precedence, as in the original check order.
            expected_len = self.hidden_size // self.num_attention_heads // 2
            if len(factors) != expected_len:
                raise ValueError(
                    f"`rope_scaling`'s {key} field must have length {expected_len}, got {len(factors)}"
                )

        for key in ("short_mscale", "long_mscale"):
            mscale = self.rope_scaling.get(key, None)
            if not isinstance(mscale, (int, float)):
                raise ValueError(f"`rope_scaling`'s {key} field must be a number, got {mscale}")

        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
        if not isinstance(original_max_position_embeddings, int):
            raise ValueError(
                "`rope_scaling`'s original_max_position_embeddings field must be an integer, "
                f"got {original_max_position_embeddings}"
            )