File size: 6,091 Bytes
958d6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import Any
import torch
from .layers_registry import attention_classes, ffns, ffns_with_megablocks, ffns_with_norm, norms
from .blocks import FusedNormAttentionNorm, MPTBlock

def pass_on_block_idx(parent: torch.nn.Module):
    """Copy ``block_idx`` and ``max_block_idx`` from a module to its descendants.

    Nothing happens when ``parent`` is missing either attribute. Otherwise
    every direct child receives the parent's two values and is then processed
    the same way, so the whole subtree ends up with identical indices.

    Args:
        parent: Module whose ``block_idx`` / ``max_block_idx`` are pushed down.
    """
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        # ``children()`` hands back an iterator object, which is always
        # truthy, so guarding the recursion on it never skipped anything.
        # Recursing into a leaf module just iterates over no children.
        pass_on_block_idx(child)

def get_act_ckpt_module(mod_name: str) -> Any:
    """Get the module type from the module name.

    ``'mptblock'`` and ``'norm_attn_norm'`` are matched case-insensitively.
    Any other name is looked up, as given, in the attention, ffn,
    ffn-with-norm, ffn-with-megablocks and norm registries, in that order.

    Args:
        mod_name: Name given in ``activation_checkpointing_target``.

    Returns:
        The module type registered under (or hard-wired to) ``mod_name``.

    Raises:
        ValueError: If ``mod_name`` matches none of the known options.
    """
    lowered = mod_name.lower()
    if lowered == 'mptblock':
        return MPTBlock
    if mod_name in attention_classes:
        return attention_classes.get(mod_name)
    if lowered == 'norm_attn_norm':
        return FusedNormAttentionNorm
    for registry in (ffns, ffns_with_norm, ffns_with_megablocks, norms):
        if mod_name in registry:
            return registry.get(mod_name)

    # List every accepted name, including the two hard-wired ones, so the
    # error message does not hide a valid option from the user.
    msg = ', '.join(
        list(attention_classes.get_all()) + list(ffns.get_all()) +
        list(ffns_with_norm.get_all()) +
        list(ffns_with_megablocks.get_all()) + list(norms.get_all()) +
        ['MPTBlock', 'norm_attn_norm'],
    )
    raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')

def parse_ele_str(ele: str, max_block_idx: int) -> list:
    """Parse a string in target_blocks and return a list of block ids to add.

    Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
    to the first n, the middle m, the last k, and the range [i, j). The
    resulting ids are clipped to ``0..max_block_idx``.

    Args:
        ele: One target_blocks element, e.g. ``'first-2'`` or ``'range-1-4'``.
        max_block_idx: Largest valid block id.

    Returns:
        The block ids selected by ``ele``, in increasing order.

    Raises:
        ValueError: If ``ele`` is not in one of the supported formats.
    """
    invalid_msg = f'Invalid target_blocks element {ele}'
    num_blocks = max_block_idx + 1

    def _count(text: str) -> int:
        # Raise instead of assert: asserts are removed under ``python -O`` and
        # the unknown-format branch below already reports ValueError.
        # ``isdecimal`` (unlike ``isdigit``) rejects characters such as
        # superscript digits that ``int`` cannot convert.
        if not text.isdecimal():
            raise ValueError(invalid_msg)
        return int(text)

    if ele.startswith('first-'):
        return list(range(min(_count(ele[6:]), num_blocks)))

    if ele.startswith('last-'):
        return list(range(max(num_blocks - _count(ele[5:]), 0), num_blocks))

    if ele.startswith('middle-'):
        num = _count(ele[7:])
        start = max(max_block_idx // 2 - num // 2, 0)
        return list(range(start, min(start + num, num_blocks)))

    if ele.startswith('range-'):
        bounds = ele[6:].split('-')
        if len(bounds) != 2:
            raise ValueError(invalid_msg)
        try:
            start, end = int(bounds[0]), int(bounds[1])
        except ValueError as err:
            # Re-raise with the same message every other malformed element gets.
            raise ValueError(invalid_msg) from err
        return list(range(max(start, 0), min(end, num_blocks)))

    raise ValueError(invalid_msg)

def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
    """Parse the user input and return a sorted list of unique block ids.

    Accepted forms of ``target_blocks``:
      * an int ``n``: selects ids ``0..n-1`` (not clipped to ``max_block_idx``);
      * a list mixing ints (taken as-is) and strings understood by
        ``parse_ele_str``;
      * a comma separated string of such elements.

    Spaces inside string elements are ignored in both the list and the
    comma-separated form.

    Args:
        target_blocks: User specification of the blocks to target.
        max_block_idx: Largest valid block id, forwarded to ``parse_ele_str``.

    Returns:
        The selected block ids, de-duplicated and in increasing order.

    Raises:
        ValueError: If ``target_blocks`` or one of its list elements has an
            unsupported type.
    """
    candidate_block_ids = []
    if isinstance(target_blocks, int):
        candidate_block_ids.extend(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.append(ele)
            elif isinstance(ele, str):
                # Drop spaces so list elements are treated like the pieces of
                # a comma separated string.
                to_add = parse_ele_str(ele.replace(' ', ''), max_block_idx)
                candidate_block_ids.extend(to_add)
            else:
                raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(' ', '').split(','):
            candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
    else:
        raise ValueError(f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
    # Sort so the result does not depend on set iteration order.
    return sorted(set(candidate_block_ids))

def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Verify that no block id is claimed by more than one mapping entry.

    Each value in ``mapping`` is either an iterable of block ids or ``-1``,
    which stands for every id in ``0..max_block_idx``. Ids outside that range
    are skipped.

    Args:
        mapping: Maps a key to the block ids it claims.
        max_block_idx: Largest valid block id.

    Raises:
        ValueError: If a block id is claimed a second time.
    """
    num_blocks = max_block_idx + 1
    # owners[i] holds the key that first claimed block i, or None if unclaimed.
    owners = [None] * num_blocks
    for key, block_ids in mapping.items():
        claimed = range(num_blocks) if block_ids == -1 else block_ids
        for idx in claimed:
            if idx < 0 or idx > max_block_idx:
                continue
            if owners[idx] is None:
                owners[idx] = key
                continue
            raise ValueError(f'Block {idx} is assigned to both {key} and {owners[idx]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')

def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
    """Build the mapping from module type to the block ids it applies to.

    A value of ``-1`` is stored when no specific blocks are given.

    Args:
        act_ckpt_target: ``None`` or ``[]`` (use ``top_module``), a single
            module name, a list of module names, or a dict mapping module
            names to a target_blocks specification.
        top_module: Key used when ``act_ckpt_target`` is ``None`` or ``[]``.
        max_block_idx: Largest valid block id, used when parsing dict values.

    Returns:
        Dict whose keys are ``top_module`` or the results of
        ``get_act_ckpt_module`` and whose values are ``-1`` or a list of
        block ids.

    Raises:
        ValueError: If ``act_ckpt_target`` is of an unsupported type.
    """
    if act_ckpt_target is None or act_ckpt_target == []:
        return {top_module: -1}

    if isinstance(act_ckpt_target, str):
        return {get_act_ckpt_module(act_ckpt_target): -1}

    if isinstance(act_ckpt_target, list):
        return {get_act_ckpt_module(name): -1 for name in act_ckpt_target}

    if isinstance(act_ckpt_target, dict):
        mod_to_blocks = {}
        for name, blocks_spec in act_ckpt_target.items():
            module_type = get_act_ckpt_module(name)
            mod_to_blocks[module_type] = get_target_block_list(blocks_spec, max_block_idx)
        return mod_to_blocks

    raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')