"""Helper functions for computing parameter counts for MPT model.

Use if generic `sum(p.numel() for p in self.parameters())`
style computation does not account for MoE parameter sharding.
The helper functions in this file account for MoE parameter
sharding in the parameter count calculation. The functions below
calculate the total parameter count and the active parameter count.

Note: MPT has both n_total_params and n_active_params methods.
"""
from typing import Union

from torch import Tensor, nn
from torch.distributed._tensor import DTensor

from .layers_registry import ffns_with_megablocks
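
# Why these helpers exist (illustrative, hypothetical numbers): with expert
# parallelism, each rank materializes only its local shard of the expert
# weights, so a plain `sum(p.numel() for p in model.parameters())` reports
# roughly 1/moe_world_size of the true expert parameter count, e.g. only 2 of
# 16 experts when moe_world_size=8. The helpers below rescale the sharded
# expert tensors to recover total and active counts.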


def module_n_params(module: nn.Module) -> int:
    """Gets the number of parameters in this module excluding child modules.

    Args:
        module (nn.Module): Module of which we get the number of parameters.

    Returns:
        An int for the number of parameters in this module.
    """
n_params = 0
for p in module.parameters(recurse=False):
n_params += p.numel()
return n_params


def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
    """Returns the element count of the local shard for a DTensor, else numel."""
    if isinstance(tensor, DTensor):
        tensor = tensor._local_tensor
    return tensor.numel()


def megablocks_n_total_params(mpt_model) -> int:
    """Calculates the number of parameters in a MegaBlocks enabled MPT model.

    MoE experts are sharded across workers, so each rank only sees its local
    shard. This function scans for MegaBlocks modules and multiplies the
    sharded expert parameter counts by the MoE world size.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
import megablocks
moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
n_total_params = 0
for module in mpt_model.modules():
if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
n_w1 = _dtensor_safe_check_numel(module.w1)
n_total_params += n_w1 * moe_world_size
n_w2 = _dtensor_safe_check_numel(module.w2)
n_total_params += n_w2 * moe_world_size
if hasattr(module, 'v1'):
n_v1 = _dtensor_safe_check_numel(module.v1)
n_total_params += n_v1 * moe_world_size
else:
n_total_params += module_n_params(module)
return n_total_params
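
# Worked example (hypothetical shapes): with moe_num_experts=16 and
# moe_world_size=8, each rank holds 16 / 8 = 2 local experts, so a local w1
# shard counts 2 * ffn_hidden * d_model elements; multiplying by
# moe_world_size=8 recovers the full 16-expert total that a plain
# parameters() sum would under-report.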


def megablocks_n_active_params(mpt_model) -> int:
    """Calculates the number of active parameters in a MegaBlocks enabled MPT.

    This requires that we calculate the number of elements per expert and
    multiply this by top k.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
import megablocks
moe_num_experts = mpt_model.config.ffn_config.get('moe_num_experts', 1)
moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
local_experts = moe_num_experts / moe_world_size
moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
n_active_params = 0
for module in mpt_model.modules():
if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
n_w1 = _dtensor_safe_check_numel(module.w1)
n_active_params += int(n_w1 / local_experts * moe_top_k)
n_w2 = _dtensor_safe_check_numel(module.w2)
n_active_params += int(n_w2 / local_experts * moe_top_k)
if hasattr(module, 'v1'):
n_v1 = _dtensor_safe_check_numel(module.v1)
n_active_params += int(n_v1 / local_experts * moe_top_k)
else:
n_active_params += module_n_params(module)
return n_active_params
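
# Worked example (hypothetical config): with moe_num_experts=16,
# moe_world_size=8 and moe_top_k=2, local_experts = 2, so each sharded expert
# tensor's numel is divided by 2 to get one expert's size and then multiplied
# by top_k=2, i.e. only the parameters of the two routed experts are counted
# as active per token.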


def mpt_get_total_params(mpt_model) -> int:
    """Calculates the total parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            total parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
return megablocks_n_total_params(mpt_model)
else:
return sum((p.numel() for p in mpt_model.parameters()))


def mpt_get_active_params(mpt_model) -> int:
    """Calculates the active parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
params = megablocks_n_active_params(mpt_model)
else:
params = sum((p.numel() for p in mpt_model.parameters()))
if not mpt_model.model.transformer.config.tie_word_embeddings:
params -= _dtensor_safe_check_numel(mpt_model.model.transformer.wte.weight)
return params
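

# Usage sketch (hypothetical; `build_composer_model` and its arguments are
# assumptions, not part of this module):
#
#   model = build_composer_model(...)        # ComposerMPTCausalLM, pre-FSDP
#   n_total = mpt_get_total_params(model)    # counts every expert on every rank
#   n_active = mpt_get_active_params(model)  # top-k experts only; untied
#                                            # embedding weights are excluded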