"""Helper function to configure MPT with MoEs."""

import inspect
from typing import Callable, Optional, Union

import torch
from packaging import version
from torch import distributed
from torch.distributed._tensor import DeviceMesh

from .ffn import resolve_ffn_hidden_size
from .layers_registry import ffns_with_megablocks


def create_process_group_ranks(ranks: tuple[int, ...]):
    """Creates a new distributed group.

    Used in create_set_process_group and create_mod_process_group methods
    below.

    This function is an alternative to `distributed.new_group(ranks)`.

    Args:
        ranks (tuple[int, ...]): Tuple of ranks of group members.

    Returns:
        A handle of distributed group that can be given to collective calls.
    """
    # Gather every rank's tuple, then deduplicate so each distinct tuple
    # becomes one subgroup. new_subgroups_by_enumeration must be called by
    # all ranks with the same enumeration, which the all_gather guarantees.
    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
    distributed.all_gather_object(ranks_gather_list, ranks)
    ranks_per_subgroup = list(set(ranks_gather_list))
    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
        ranks_per_subgroup,
    )
    return group


def create_set_process_group(k: int):
    """Creates a new distributed group using sets of k GPUs.

    For example, if you have 16 GPUs and input k=4, the resulting process
    groups will have ranks:
        process group 0 ranks: [ 0,  1,  2,  3]
        process group 1 ranks: [ 4,  5,  6,  7]
        process group 2 ranks: [ 8,  9, 10, 11]
        process group 3 ranks: [12, 13, 14, 15]

    Args:
        k (int): Number of GPUs to use in set size.

    Raises:
        RuntimeError: If world_size is not divisible by k.

    Returns:
        A handle of distributed group that can be given to collective calls.
    """
    world_size = distributed.get_world_size()
    if world_size % k != 0:
        raise RuntimeError(f'world_size={world_size!r} must be divisible by k={k!r}.')
    # Each rank joins the contiguous set of k ranks it falls into.
    start = distributed.get_rank() // k * k
    ranks = tuple(range(start, start + k))
    return create_process_group_ranks(ranks)


def get_megablocks_device_mesh(
    device_mesh_cfg: Optional[tuple[int, ...]],
    moe_world_size: int,
    world_size: int,
) -> DeviceMesh:
    """Helper function to get the device mesh for MegaBlocks MoE.

    Args:
        device_mesh_cfg (Optional[tuple[int, ...]]): The device mesh
            configuration specification.
        moe_world_size (int): The MoE world size.
        world_size (int): The world size.

    Raises:
        ValueError: If the device mesh configuration is not valid.

    Returns:
        The device mesh for MegaBlocks MoE.
    """
    from torch.distributed._tensor.device_mesh import init_device_mesh

    if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
        # A length-1 config overrides the observed world size; otherwise the
        # caller-supplied world_size is used.
        if device_mesh_cfg is not None:
            world_size = device_mesh_cfg[0]
        sharding_group_dim = world_size // moe_world_size
        device_mesh = init_device_mesh(
            'cuda',
            (sharding_group_dim, moe_world_size),
            mesh_dim_names=('weight_parallel', 'expert_parallel'),
        )
    else:
        raise ValueError(f'device_mesh_cfg={device_mesh_cfg!r} must be length 1')

    return device_mesh


def config_megablocks_moe_args(
    ffn_config: dict,
    d_model: int,
    expansion_ratio: Union[int, float],
    n_layers: int,
    get_device_mesh: Callable,
) -> dict:
    """Configures `ffn_config` for MegaBlocks MoE.

    We prepare all necessary arguments for
    `megablocks.layers.arguments.Arguments` so that process groups can be
    initialized and shared across all blocks in the network.

    Args:
        ffn_config (dict): FFN configuration before the MegaBlocks MoE is
            configured.
        d_model (int): Hidden size of the network.
        expansion_ratio (Union[int, float]): Expansion ratio in FFN.
        n_layers (int): Number of blocks used in the network.
        get_device_mesh (Callable): Function to get the device mesh. Takes in
            the device mesh config and the MoE world size.

    Raises:
        RuntimeError: If MegaBlocks is not installed, or if expert parallelism
            is requested on torch < 2.2.
        ValueError: If the world size / MoE world size configuration is
            invalid, or if `lbl_process_group` is not a recognized option.

    Returns:
        ffn_config (dict): FFN configuration with MegaBlocks MoE configured.
    """
    try:
        import megablocks
    except ImportError as e:
        # Was a bare `except:`; catch only the import failure and chain it so
        # the underlying cause is preserved.
        raise RuntimeError(
            'Requirements for MegaBlocks not installed; see install instructions in `README.md`.',
        ) from e

    ffn_config.setdefault('fp16', False)
    ffn_config.setdefault('bf16', False)
    ffn_config['num_layers'] = n_layers

    # Pop keys that `megablocks.layers.arguments.Arguments` does not accept;
    # they are re-attached to the returned config after Arguments is built.
    ffn_type = ffn_config.pop('ffn_type')
    fc_type = ffn_config.pop('fc_type')
    ffn_act_fn = ffn_config.pop('ffn_act_fn', None)

    world_size = 1
    moe_world_size = ffn_config.pop('moe_world_size')
    device_mesh = None
    device_mesh_cfg = ffn_config.pop('device_mesh', None)
    if moe_world_size > 1:
        # Strip any `.devYYYYMMDD` suffix before comparing versions.
        if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.2.0'):
            # Fixed: message was missing the `f` prefix, so the placeholder
            # was emitted literally.
            raise RuntimeError(
                f'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.',
            )

        world_size = distributed.get_world_size()
        if world_size < moe_world_size or world_size % moe_world_size:
            raise ValueError(
                f'Invalid world size configuration: world_size={world_size!r} and moe_world_size={moe_world_size!r}',
            )

        device_mesh = get_device_mesh(
            device_mesh_cfg=device_mesh_cfg,
            moe_world_size=moe_world_size,
            world_size=world_size,
        )

        ffn_config['moe_expert_model_parallelism'] = True
        ffn_config['expert_parallel_group'] = device_mesh['expert_parallel'].get_group(0)

    lbl_process_group = ffn_config.get('lbl_process_group', None)
    if lbl_process_group is not None:
        if lbl_process_group == 'expert_group':
            lbl_process_group = ffn_config['expert_parallel_group']
        elif lbl_process_group == 'global_group':
            lbl_process_group = distributed.group.WORLD
        elif isinstance(lbl_process_group, int):
            # An int > 1 means "average the load-balancing loss over sets of
            # that many GPUs"; 1 (or less) means no cross-rank averaging.
            if lbl_process_group > 1:
                lbl_process_group = create_set_process_group(lbl_process_group)
            else:
                lbl_process_group = None
        elif not isinstance(lbl_process_group, distributed.ProcessGroup):
            # Reconstructed: the original literal contained a raw line break
            # and had lost the integer option from the list.
            raise ValueError(
                f'Unknown lbl_process_group={lbl_process_group!r}. Options are: none | a process group | ``expert_group`` | ``global_group`` | <GROUP_SIZE>.',
            )
        ffn_config['lbl_process_group'] = lbl_process_group

    ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio)
    ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size)

    # Keep only the kwargs Arguments actually declares.
    args_to_keep_in_ffn_config = inspect.signature(
        megablocks.layers.arguments.Arguments,
    ).parameters
    ffn_config = {k: v for k, v in ffn_config.items() if k in args_to_keep_in_ffn_config}

    args = megablocks.layers.arguments.Arguments(hidden_size=d_model, **ffn_config)
    ffn_config['args'] = args
    ffn_config['device_mesh'] = device_mesh
    ffn_config['moe_world_size'] = moe_world_size
    ffn_config['ffn_type'] = ffn_type
    ffn_config['fc_type'] = fc_type
    ffn_config['ffn_act_fn'] = ffn_act_fn

    return ffn_config


def config_moe_args(
    ffn_config: dict,
    d_model: int,
    expansion_ratio: Union[int, float],
    n_layers: int,
) -> dict:
    """Configures `ffn_config` for MoE.

    Args:
        ffn_config (dict): FFN configuration before the MoE is configured.
        d_model (int): Hidden size of the network.
        expansion_ratio (int, float): Expansion ratio in FFN.
        n_layers (int): Number of blocks used in the network.

    Raises:
        ValueError: If the FFN type is not a recognized MoE type.

    Returns:
        ffn_config (dict): FFN configuration with MoE configured.
    """
    if ffn_config['ffn_type'] in ffns_with_megablocks:
        return config_megablocks_moe_args(
            ffn_config=ffn_config,
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            n_layers=n_layers,
            get_device_mesh=get_megablocks_device_mesh,
        )
    else:
        raise ValueError(f"Invalid ffn_type ({ffn_config['ffn_type']}).")