File size: 5,811 Bytes
2cc518e

from typing import Any
import torch
from .attention import ATTN_CLASS_REGISTRY
from .blocks import MPTBlock
from .ffn import FFN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY

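def pass_on_block_idx(parent: torch.nn.Module) -> None:
    """Recursively copy ``block_idx`` and ``max_block_idx`` from ``parent`` to all of its descendants."""
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        # `children()` returns a generator, which is always truthy, so test for an actual child.
        if next(child.children(), None) is not None:
            pass_on_block_idx(child)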
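Here is a minimal sketch of the propagation that does not depend on torch. It uses a hypothetical `Node` class that exposes only a `children()` iterator as a stand-in for `torch.nn.Module`. The function `propagate_block_idx` is an illustrative mirror of `pass_on_block_idx`. A leaf has no children, so the sketch recurses into every child without an emptiness check.

```python
from typing import Iterator, List, Optional


class Node:
    """Hypothetical stand-in for torch.nn.Module exposing only children()."""

    def __init__(self, name: str, kids: Optional[List["Node"]] = None):
        self.name = name
        self._kids = kids or []

    def children(self) -> Iterator["Node"]:
        return iter(self._kids)


def propagate_block_idx(parent: Node) -> None:
    # Nothing to propagate if the parent was never tagged with block indices.
    if not hasattr(parent, "block_idx") or not hasattr(parent, "max_block_idx"):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        propagate_block_idx(child)  # a leaf simply has no children to visit


leaf = Node("linear")
block = Node("block", [Node("attn"), Node("ffn", [leaf])])
block.block_idx, block.max_block_idx = 3, 23
propagate_block_idx(block)
print(leaf.block_idx, leaf.max_block_idx)  # 3 23
```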

def get_act_ckpt_module(mod_name: str) -> Any:
    """Get the module type from the module name."""
    if mod_name.lower() == 'mptblock':
        mod_type = MPTBlock
    elif mod_name in ATTN_CLASS_REGISTRY:
        mod_type = ATTN_CLASS_REGISTRY[mod_name]
    elif mod_name in FFN_CLASS_REGISTRY:
        mod_type = FFN_CLASS_REGISTRY[mod_name]
    elif mod_name in NORM_CLASS_REGISTRY:
        mod_type = NORM_CLASS_REGISTRY[mod_name]
    else:
        msg = ', '.join(list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
        raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
    return mod_type
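A self-contained sketch can illustrate the lookup order. The placeholder classes and the registry contents (`multihead_attention`, `mptmlp`, `layernorm`) are hypothetical stand-ins for `ATTN_CLASS_REGISTRY`, `FFN_CLASS_REGISTRY`, and `NORM_CLASS_REGISTRY`, and `lookup` is an illustrative name. Only `mptblock` is matched case-insensitively. Registry keys must match exactly.

```python
from typing import Dict, Type


# Hypothetical placeholders for the real module classes and registries.
class MPTBlock: ...
class Attention: ...
class MLP: ...
class LayerNorm: ...


ATTN: Dict[str, Type] = {"multihead_attention": Attention}
FFN: Dict[str, Type] = {"mptmlp": MLP}
NORM: Dict[str, Type] = {"layernorm": LayerNorm}


def lookup(mod_name: str) -> Type:
    """Resolve a target name in the same order as get_act_ckpt_module."""
    if mod_name.lower() == "mptblock":  # the only case-insensitive match
        return MPTBlock
    for registry in (ATTN, FFN, NORM):  # attention, then FFN, then norm
        if mod_name in registry:
            return registry[mod_name]
    options = ", ".join([*ATTN, *FFN, *NORM, "MPTBlock"])
    raise ValueError(f"{mod_name} is not a recognized option out of available options {options}.")


print(lookup("MPTBlock").__name__)  # MPTBlock
print(lookup("mptmlp").__name__)    # MLP
```

For example, `lookup("MPTMLP")` raises a `ValueError` because the registry key is lowercase.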

def parse_ele_str(ele: str, max_block_idx: int) -> list:
    """Parse a string in target_blocks and return a list of block ids to add.

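    Supported formats are first-n, middle-m, last-k, and range-i-j, which select
    the first n blocks, the middle m blocks, the last k blocks, and the blocks in
    the half-open range [i, j), respectively. Block ids are clamped to
    [0, max_block_idx].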
    """
    to_add = None
    if ele.startswith('first-'):
        assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
        to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
    elif ele.startswith('last-'):
        assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
        to_add = list(range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
    elif ele.startswith('middle-'):
        assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
        num = int(ele[7:])
        start = max(max_block_idx // 2 - num // 2, 0)
        end = min(start + num, max_block_idx + 1)
        to_add = list(range(start, end))
    elif ele.startswith('range-'):
        r = ele[6:].split('-')
        assert len(r) == 2, f'Invalid target_blocks element {ele}'
        start, end = (int(r[0]), int(r[1]))
        start = max(start, 0)
        end = min(end, max_block_idx + 1)
        to_add = list(range(start, end))
    else:
        raise ValueError(f'Invalid target_blocks element {ele}')
    return to_add
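Here is a worked example for a 10-block model (`max_block_idx = 9`). The sketch `parse_spec` is a hypothetical re-expression of the rules in `parse_ele_str` using `str.partition`. It differs from the original in one respect: it raises `ValueError` for every malformed element, whereas the original relies on `assert` for the first/last/middle forms.

```python
def parse_spec(ele: str, max_block_idx: int) -> list:
    """Hypothetical re-expression of parse_ele_str; ids are clamped to [0, max_block_idx]."""
    n_blocks = max_block_idx + 1
    kind, _, arg = ele.partition("-")
    if kind == "first" and arg.isdigit():
        return list(range(min(int(arg), n_blocks)))
    if kind == "last" and arg.isdigit():
        return list(range(max(n_blocks - int(arg), 0), n_blocks))
    if kind == "middle" and arg.isdigit():
        num = int(arg)
        start = max(max_block_idx // 2 - num // 2, 0)
        return list(range(start, min(start + num, n_blocks)))
    if kind == "range":
        lo, sep, hi = arg.partition("-")
        if sep and lo.isdigit() and hi.isdigit():
            return list(range(max(int(lo), 0), min(int(hi), n_blocks)))  # half-open [i, j)
    raise ValueError(f"Invalid target_blocks element {ele}")


for spec in ["first-3", "last-3", "middle-4", "range-8-20"]:
    print(spec, parse_spec(spec, max_block_idx=9))
# first-3 [0, 1, 2]
# last-3 [7, 8, 9]
# middle-4 [2, 3, 4, 5]
# range-8-20 [8, 9]
```

The window for `middle-m` starts at `max_block_idx // 2 - m // 2`. As a result, `middle-4` on 10 blocks yields `[2, 3, 4, 5]` rather than `[3, 4, 5, 6]`.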

def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
    """Parse the user input and return a list of block ids."""
    candidate_block_ids = []
    if isinstance(target_blocks, int):
        candidate_block_ids = list(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.append(ele)
            elif isinstance(ele, str):
                to_add = parse_ele_str(ele, max_block_idx)
                candidate_block_ids.extend(to_add)
            else:
                raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
    elif isinstance(target_blocks, str):
        target_blocks = target_blocks.replace(' ', '')
        for ele in target_blocks.split(','):
            to_add = parse_ele_str(ele, max_block_idx)
            candidate_block_ids.extend(to_add)
    else:
        raise ValueError(f'target_blocks must be a single integer, a list of integers, a comma-separated string made of "first-n", "middle-m", "last-k", or "range-i-j", or a list mixing integers and those strings, but got {type(target_blocks)}')
    # Drop duplicates; sorting keeps the block order deterministic.
    candidate_block_ids = sorted(set(candidate_block_ids))
    return candidate_block_ids
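The following sketch exercises the three accepted input shapes. The function `collect_block_ids` mirrors the dispatch in `get_target_block_list` but takes the element parser as an argument. The parser used here, `first_last_only`, is a hypothetical mini-parser that understands only `first-n` and `last-k`, which keeps the example short. The sketch sorts after de-duplicating so that the output order is deterministic.

```python
from typing import Callable, List, Union

Spec = Union[int, str, List[Union[int, str]]]


def collect_block_ids(target_blocks: Spec, max_block_idx: int,
                      parse: Callable[[str, int], List[int]]) -> List[int]:
    ids: List[int] = []
    if isinstance(target_blocks, int):
        ids = list(range(target_blocks))  # an int n means "the first n blocks"
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                ids.append(ele)
            elif isinstance(ele, str):
                ids.extend(parse(ele, max_block_idx))
            else:
                raise ValueError(f"unsupported element {ele!r}")
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(" ", "").split(","):
            ids.extend(parse(ele, max_block_idx))
    else:
        raise ValueError(f"unsupported type {type(target_blocks)}")
    return sorted(set(ids))  # de-duplicate, then fix the order


def first_last_only(ele: str, max_block_idx: int) -> List[int]:
    """Hypothetical mini-parser that understands only first-n and last-k."""
    kind, _, n = ele.partition("-")
    if kind == "first":
        return list(range(min(int(n), max_block_idx + 1)))
    if kind == "last":
        return list(range(max(max_block_idx - int(n) + 1, 0), max_block_idx + 1))
    raise ValueError(f"Invalid target_blocks element {ele}")


print(collect_block_ids(3, 9, first_last_only))                  # [0, 1, 2]
print(collect_block_ids("first-2, last-2", 9, first_last_only))  # [0, 1, 8, 9]
print(collect_block_ids([0, "first-2", 5], 9, first_last_only))  # [0, 1, 5]
```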

def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Check if the block ids in the mapping overlap with each other."""
    all_blocks = [None] * (max_block_idx + 1)
    for k, v in mapping.items():
        if v == -1:
            v = list(range(max_block_idx + 1))
        for vv in v:
            if vv < 0 or vv > max_block_idx:
                continue
            elif all_blocks[vv] is not None:
                raise ValueError(f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
            else:
                all_blocks[vv] = k
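Here is a minimal sketch of the same overlap check. It uses string keys in place of the module classes that `build_act_ckpt_mod_to_blocks` produces, and `find_overlap` is an illustrative name. Each block records its first owner, so a second claim on the same block raises an error.

```python
from typing import Dict, List, Optional, Union


def find_overlap(mapping: Dict[str, Union[int, List[int]]], max_block_idx: int) -> None:
    """Hypothetical mirror of check_mapping_blocks_overlap with string keys."""
    owner: List[Optional[str]] = [None] * (max_block_idx + 1)
    for name, ids in mapping.items():
        if ids == -1:  # -1 is shorthand for every block
            ids = list(range(max_block_idx + 1))
        for i in ids:
            if not 0 <= i <= max_block_idx:
                continue  # out-of-range ids are skipped, not reported
            if owner[i] is not None:
                raise ValueError(f"Block {i} is assigned to both {name} and {owner[i]}.")
            owner[i] = name


find_overlap({"MPTBlock": [0, 1], "mptmlp": [2, 3]}, 3)  # disjoint: passes silently
try:
    find_overlap({"MPTBlock": [0, 1], "mptmlp": [1]}, 3)
except ValueError as err:
    print(err)  # Block 1 is assigned to both mptmlp and MPTBlock.
```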

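def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
    """Map each activation-checkpointing module type to the block ids it applies to.

    A value of -1 means every block. If act_ckpt_target is None or an empty list,
    top_module is checkpointed in every block; a dict maps module names to
    target_blocks specs, which are parsed with get_target_block_list.
    """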
    act_ckpt_mod_to_blocks = {}
    if act_ckpt_target is None or act_ckpt_target == []:
        mod = top_module
        act_ckpt_mod_to_blocks[mod] = -1
    elif isinstance(act_ckpt_target, str):
        mod = get_act_ckpt_module(act_ckpt_target)
        act_ckpt_mod_to_blocks[mod] = -1
    elif isinstance(act_ckpt_target, list):
        for target in act_ckpt_target:
            mod = get_act_ckpt_module(target)
            act_ckpt_mod_to_blocks[mod] = -1
    elif isinstance(act_ckpt_target, dict):
        for k, v in act_ckpt_target.items():
            mod = get_act_ckpt_module(k)
            block_ids = get_target_block_list(v, max_block_idx)
            act_ckpt_mod_to_blocks[mod] = block_ids
    else:
        raise ValueError(f'activation_checkpointing_target must be a single string, a list, or a dict, but got {type(act_ckpt_target)}')
    return act_ckpt_mod_to_blocks
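The following sketch shows the accepted shapes of `activation_checkpointing_target` and the mappings they produce. To stay self-contained, this hypothetical `build_mapping` makes two simplifications:

- It keeps target names as strings instead of resolving them through `get_act_ckpt_module`.
- In the dict form, it accepts only lists of ints as block specs.

The names `grouped_query_attention` and `mptmlp` are illustrative.

```python
from typing import Any, Dict, List, Union

BlockSpec = Union[int, List[int]]  # -1 means "all blocks"


def build_mapping(target: Any, top_module: str) -> Dict[str, BlockSpec]:
    """Hypothetical mirror of build_act_ckpt_mod_to_blocks with string keys."""
    if target is None or target == []:
        # No target given: checkpoint the top-level module in every block.
        return {top_module: -1}
    if isinstance(target, str):
        return {target: -1}
    if isinstance(target, list):
        return {name: -1 for name in target}
    if isinstance(target, dict):
        # Assumption: block specs are plain int lists; the original parses them
        # with get_target_block_list, which also accepts ints and strings.
        return {name: sorted(set(ids)) for name, ids in target.items()}
    raise ValueError(f"activation_checkpointing_target must be a single string, a list, or a dict, but got {type(target)}")


print(build_mapping(None, "MPTBlock"))  # {'MPTBlock': -1}
print(build_mapping(["grouped_query_attention", "mptmlp"], "MPTBlock"))
print(build_mapping({"MPTBlock": [1, 0, 1], "mptmlp": [2]}, "MPTBlock"))  # {'MPTBlock': [0, 1], 'mptmlp': [2]}
```

Only the dict form restricts checkpointing to particular blocks. The other forms map every target to -1.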