"""Helper functions for computing parameter counts for MPT model.

Use if generic `sum(p.numel() for p in self.parameters())` style computation
does not account for MoE parameter sharding. The helper functions in this file
account for MoE parameter sharding in the parameter count calculation. The
functions below calculate the total parameter count and the active parameter
count. Note: MPT has both n_total_params and n_active_params methods.
"""

from typing import Union

from torch import Tensor, nn
from torch.distributed._tensor import DTensor

from .layers_registry import ffns_with_megablocks


def module_n_params(module: nn.Module) -> int:
    """Gets the number of parameters in this module excluding child modules.

    Args:
        module (nn.Module): Module of which we get the number of parameters.

    Returns:
        An int for the number of parameters in this module (0 if it owns none).
    """
    return sum(p.numel() for p in module.parameters(recurse=False))


def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
    """Return the number of elements stored locally for ``tensor``.

    For a ``DTensor`` this is the numel of its local tensor (``_local_tensor``),
    not of the full global tensor; a plain ``Tensor`` reports its own numel.
    """
    if isinstance(tensor, DTensor):
        tensor = tensor._local_tensor
    return tensor.numel()


def _moe_world_size(mpt_model) -> int:
    """Return ``ffn_config['moe_world_size']``, treating a missing/None value as 1.

    Without this default, a config lacking the key would later fail with a
    ``TypeError`` when multiplying or dividing by ``None``.
    """
    world_size = mpt_model.config.ffn_config.get('moe_world_size')
    return 1 if world_size is None else world_size


def _megablocks_mlp_types() -> tuple:
    """Return the MegaBlocks expert MLP classes whose weights are sharded."""
    import megablocks  # imported lazily: only needed for MegaBlocks FFNs
    return (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)


def _expert_weight_numels(module: nn.Module):
    """Yield the local numel of each expert weight (``w1``, ``w2``, ``v1``)."""
    yield _dtensor_safe_check_numel(module.w1)
    yield _dtensor_safe_check_numel(module.w2)
    # ``v1`` only exists on some MLP variants (presumably the GLU ones).
    if hasattr(module, 'v1'):
        yield _dtensor_safe_check_numel(module.v1)


def megablocks_n_total_params(mpt_model) -> int:
    """Calculates the number of parameters in a MegaBlocks enabled MPT model.

    MoE experts are sharded across workers. This function scans for MegaBlocks
    modules then multiplies expert params count by MoE world size.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    expert_mlp_types = _megablocks_mlp_types()
    moe_world_size = _moe_world_size(mpt_model)

    n_total_params = 0
    for module in mpt_model.modules():
        if isinstance(module, expert_mlp_types):
            # Each rank holds only its local slice of the experts, so scale the
            # local count by the number of ranks the experts are sharded over.
            n_total_params += sum(
                n_local * moe_world_size
                for n_local in _expert_weight_numels(module)
            )
        else:
            n_total_params += module_n_params(module)
    return n_total_params


def megablocks_n_active_params(mpt_model) -> int:
    """Calculates the number of active parameters in a MegaBlocks enabled MPT.

    This requires we calculate the number of elements per expert and multiply
    this by top k.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
    expert_mlp_types = _megablocks_mlp_types()
    ffn_config = mpt_model.config.ffn_config
    moe_num_experts = ffn_config.get('moe_num_experts', 1)
    moe_world_size = _moe_world_size(mpt_model)
    moe_top_k = ffn_config.get('moe_top_k', 1)

    n_active_params = 0
    for module in mpt_model.modules():
        if isinstance(module, expert_mlp_types):
            for n_local in _expert_weight_numels(module):
                # n_local spans moe_num_experts / moe_world_size local experts.
                # Per-expert size * top_k, i.e. n_local / local_experts * top_k,
                # computed in exact integer arithmetic instead of floats.
                n_active_params += (
                    n_local * moe_world_size * moe_top_k // moe_num_experts
                )
        else:
            n_active_params += module_n_params(module)
    return n_active_params


def mpt_get_total_params(mpt_model) -> int:
    """Calculates the total parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the total number of
            parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        return megablocks_n_total_params(mpt_model)
    return sum(p.numel() for p in mpt_model.parameters())


def mpt_get_active_params(mpt_model) -> int:
    """Calculates the active parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model. When word
        embeddings are untied, the input embedding (``wte``) is excluded.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        params = megablocks_n_active_params(mpt_model)
    else:
        params = sum(p.numel() for p in mpt_model.parameters())
    if not mpt_model.model.transformer.config.tie_word_embeddings:
        # With untied embeddings ``wte`` acts only as an input lookup table (no
        # matmul), so it is excluded; presumably this keeps the count aligned
        # with FLOP-relevant weights. When tied, the same weight also serves as
        # the output projection and stays counted.
        params -= _dtensor_safe_check_numel(
            mpt_model.model.transformer.wte.weight,
        )
    return params