|
from typing import Any |
|
import torch |
|
from .attention import ATTN_CLASS_REGISTRY |
|
from .blocks import MPTBlock |
|
from .ffn import FFN_CLASS_REGISTRY |
|
from .norm import NORM_CLASS_REGISTRY |
|
|
|
def pass_on_block_idx(parent: torch.nn.Module) -> None:
    """Copy ``block_idx`` and ``max_block_idx`` from ``parent`` onto all its descendants.

    Does nothing if ``parent`` lacks either attribute. Otherwise every direct
    child gets both attributes set to the parent's values, and the function
    recurses into each child, so the whole submodule tree is labelled.

    Args:
        parent: The module whose index attributes are propagated downward.
    """
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        # The original guarded this with `if child.children():`, but that tests a
        # generator object, which is always truthy. Recursing unconditionally is
        # equivalent: on a leaf module the loop above simply iterates nothing.
        pass_on_block_idx(child)
|
|
|
def get_act_ckpt_module(mod_name: str) -> Any:
    """Resolve an activation-checkpointing target name to its module class.

    ``'mptblock'`` is matched case-insensitively and maps to ``MPTBlock``.
    Any other name is looked up, case-sensitively and in this order, in the
    attention, FFN and norm class registries.

    Args:
        mod_name: Name taken from ``activation_checkpointing_target``.

    Returns:
        The registered class for ``mod_name``.

    Raises:
        ValueError: If ``mod_name`` is found in none of the registries.
    """
    if mod_name.lower() == 'mptblock':
        return MPTBlock

    registries = (ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY)
    for registry in registries:
        if mod_name in registry:
            return registry[mod_name]

    # Build the list of valid options for the error message.
    options = [name for registry in registries for name in registry.keys()]
    options.append('MPTBlock')
    msg = ', '.join(options)
    raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
|
|
|
def parse_ele_str(ele: str, max_block_idx: int) -> list:
    """Parse a string in target_blocks and return a list of block ids to add.

    Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
    to the first n, the middle m, the last k, and the range [i, j). Results are
    clipped to the valid block ids ``0..max_block_idx``.

    Args:
        ele: A single target_blocks element, e.g. ``'first-2'`` or ``'range-1-4'``.
        max_block_idx: The largest valid block index.

    Returns:
        The (possibly empty) ascending list of block ids selected by ``ele``.

    Raises:
        ValueError: If ``ele`` does not match a supported format. Explicit
            raises replace the previous ``assert`` checks, which are stripped
            when Python runs with ``-O``.
    """

    def _invalid() -> ValueError:
        return ValueError(f'Invalid target_blocks element {ele}')

    def _count(text: str) -> int:
        # isdecimal() rejects signs, whitespace and non-decimal digits such as
        # '²' (which isdigit() accepts but int() cannot parse).
        if not text.isdecimal():
            raise _invalid()
        return int(text)

    num_blocks = max_block_idx + 1

    if ele.startswith('first-'):
        return list(range(min(_count(ele[6:]), num_blocks)))

    if ele.startswith('last-'):
        num = _count(ele[5:])
        return list(range(max(num_blocks - num, 0), num_blocks))

    if ele.startswith('middle-'):
        num = _count(ele[7:])
        start = max(max_block_idx // 2 - num // 2, 0)
        end = min(start + num, num_blocks)
        return list(range(start, end))

    if ele.startswith('range-'):
        bounds = ele[6:].split('-')
        if len(bounds) != 2:
            raise _invalid()
        try:
            start, end = int(bounds[0]), int(bounds[1])
        except ValueError as err:
            raise _invalid() from err
        return list(range(max(start, 0), min(end, num_blocks)))

    raise _invalid()
|
|
|
def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
    """Parse the user input and return a list of block ids.

    Accepted forms of ``target_blocks``:
      * an int ``n``: block ids ``0..n-1``. Unlike ``'first-n'``, this is not
        clipped to ``max_block_idx``.
      * a list whose elements are ints (kept as-is) or strings in the formats
        understood by ``parse_ele_str``.
      * a comma-separated string of such formats. Spaces are ignored.

    Args:
        target_blocks: The user specification described above.
        max_block_idx: The largest valid block index, forwarded to ``parse_ele_str``.

    Returns:
        The de-duplicated block ids in ascending order.

    Raises:
        ValueError: If ``target_blocks`` (or a list element) has an unsupported
            type, or a string element cannot be parsed.
    """
    candidate_block_ids = []
    if isinstance(target_blocks, int):
        candidate_block_ids = list(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.append(ele)
            elif isinstance(ele, str):
                candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
            else:
                raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(' ', '').split(','):
            candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
    else:
        raise ValueError(f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')

    # sorted() makes the order deterministic; list(set(...)) depended on set
    # iteration order, which is an implementation detail.
    return sorted(set(candidate_block_ids))
|
|
|
def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Ensure no block id is claimed by more than one entry of ``mapping``.

    Each value of ``mapping`` is either ``-1``, meaning every block
    ``0..max_block_idx``, or an iterable of block ids. Ids outside
    ``0..max_block_idx`` are ignored.

    Args:
        mapping: Maps a checkpointing target to the block ids it covers.
        max_block_idx: The largest valid block index.

    Raises:
        ValueError: On the first block id that two entries both claim.
    """
    num_blocks = max_block_idx + 1
    owner = [None] * num_blocks
    for target, block_ids in mapping.items():
        ids = range(num_blocks) if block_ids == -1 else block_ids
        for idx in ids:
            if not 0 <= idx <= max_block_idx:
                continue
            previous = owner[idx]
            if previous is not None:
                raise ValueError(f'Block {idx} is assigned to both {target} and {previous}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
            owner[idx] = target
|
|
|
def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
    """Map each activation-checkpointing module type to the blocks it applies to.

    A value of ``-1`` means "all blocks".

    Args:
        act_ckpt_target: One of the following.
            * ``None`` or ``[]``: maps ``top_module`` to ``-1``.
            * a string: the named module type maps to ``-1``.
            * a list of strings: each named module type maps to ``-1``.
            * a dict: each key names a module type. Its value is parsed by
              ``get_target_block_list`` into block ids.
        top_module: The entry used when no target is specified.
        max_block_idx: The largest valid block index, forwarded to
            ``get_target_block_list``.

    Returns:
        A dict from module type (or ``top_module``) to ``-1`` or a list of block ids.

    Raises:
        ValueError: If ``act_ckpt_target`` has an unsupported type, or a name or
            block specification is not recognized.
    """
    if act_ckpt_target is None or act_ckpt_target == []:
        return {top_module: -1}
    if isinstance(act_ckpt_target, str):
        return {get_act_ckpt_module(act_ckpt_target): -1}
    if isinstance(act_ckpt_target, list):
        return {get_act_ckpt_module(name): -1 for name in act_ckpt_target}
    if isinstance(act_ckpt_target, dict):
        # Dict comprehensions evaluate the key before the value (Python 3.8+),
        # matching the original resolve-module-then-parse-blocks order.
        return {
            get_act_ckpt_module(name): get_target_block_list(blocks, max_block_idx)
            for name, blocks in act_ckpt_target.items()
        }
    raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')