import json
import sys
from argparse import Namespace
import torch
import os


def load_hyperparam(default_args):
    """
    Load arguments from argparse defaults and a JSON config file.
    Priority: default options < config file < command-line args.
    """
    with open(default_args.config_path, mode="r", encoding="utf-8") as f:
        config_args_dict = json.load(f)

    default_args_dict = vars(default_args)

    # Options explicitly passed on the command line take the highest priority;
    # collect their names here, stripping any "=value" suffix.
    command_line_args_dict = {k: default_args_dict[k] for k in [
        a[2:].split("=")[0] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
    ]}
    default_args_dict.update(config_args_dict)
    default_args_dict.update(command_line_args_dict)
    args = Namespace(**default_args_dict)

    return args
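
# A minimal usage sketch (hypothetical option names; a real script defines its
# own argparse options). Command-line values override the config file, which
# in turn overrides the argparse defaults:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--config_path", type=str, default="models/config.json")
#   parser.add_argument("--batch_size", type=int, default=32)
#   args = load_hyperparam(parser.parse_args())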


def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
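    """
    Load a single-file PyTorch checkpoint into `model_to_load` under DeepSpeed
    ZeRO-3, where every rank holds only a partition of each parameter. The
    checkpoint's `embedding.word.embedding.weight` and
    `target.lm.output_layer.weight` keys are first renamed to the
    `embedding.embedding.weight` and `target.lm.weight` names this model uses.
    """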
    # Remap checkpoint keys from the on-disk naming scheme to the attribute
    # names this model uses.

    # copy state_dict so _load_from_state_dict can modify it
    state_dict = torch.load(model_path, map_location="cpu")
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    state_dict['target.lm.weight'] = state_dict['target.lm.output_layer.weight']
    del state_dict['target.lm.output_layer.weight']
    state_dict['embedding.embedding.weight'] = state_dict['embedding.word.embedding.weight']
    del state_dict['embedding.word.embedding.weight']

    # Keep the state_dict's serialization metadata consistent with the renamed
    # keys, dropping entries for submodules that no longer appear.
    if metadata is not None:
        metadata['embedding.embedding'] = metadata['embedding.word.embedding']
        metadata['target.lm'] = metadata['target.lm.output_layer']
        if metadata.get('embedding.dropout', None) is not None:
            del metadata['embedding.dropout']
        del metadata['embedding.word']
        del metadata['embedding.word.embedding']
        del metadata['target.lm.output_layer']
        del metadata['target.lm.softmax']
        del metadata['target.lm.criterion']
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            # Imported lazily: deepspeed is only needed on this ZeRO-3 path.
            import deepspeed
            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            if len(params_to_gather) > 0:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it can be collected by the GC earlier. Note that
    # it is a local copy of the loaded dict, so deleting it is safe.
    del state_dict

    return model_to_load
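
# Usage sketch, assuming torch.distributed is initialized and the model was
# constructed under DeepSpeed ZeRO-3 (the checkpoint path is illustrative):
#
#   model = _load_state_dict_into_model(model, "models/pytorch_model.bin")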


def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
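    """
    Recursively wrap the weights of every `bnb.nn.Linear8bitLt` module in
    `bnb.nn.Int8Params` (with fp16 weights disabled), so they are quantized to
    int8 when the model is moved to the GPU. Modules whose qualified name
    matches an entry in `modules_to_not_convert` ("lm", the output head, by
    default) are left untouched.
    """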
    import bitsandbytes as bnb
    modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)

        if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                model._modules[name].weight = bnb.nn.Int8Params(
                    module.weight.data,
                    requires_grad=False,
                    has_fp16_weights=False
                )
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model
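
# Usage sketch, assuming the model's Linear layers were already built as
# bnb.nn.Linear8bitLt modules holding ordinary fp16/fp32 weights:
#
#   model = convert_normal_parameter_to_int8(model)
#   model = model.to("cuda")  # Int8Params are quantized during the move to GPU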


def load_model(model, model_path):
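    """
    Load weights into `model` from either a sharded checkpoint directory
    (indexed by `pytorch_model.bin.index.json`) or a single checkpoint file,
    remapping the word-embedding and output-layer keys to the model's
    parameter names and freezing every parameter.
    """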
    if os.path.isdir(model_path):
        index_filename = os.path.join(model_path, 'pytorch_model.bin.index.json')
        with open(index_filename, "r", encoding="utf-8") as f:
            index = json.load(f)
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
        for shard_file in shard_filenames:
            shard_checkpoint = torch.load(shard_file, map_location='cpu')
            for name, parameter in model.named_parameters():
                if shard_checkpoint.get(name, None) is not None:
                    # The output layer and word embedding live under different
                    # key names in the checkpoint, so remap them explicitly.
                    if 'target' in name:
                        parameter.data = shard_checkpoint['target.lm.output_layer.weight']
                    elif 'embedding' in name:
                        parameter.data = shard_checkpoint['embedding.word.embedding.weight']
                    else:
                        parameter.data = shard_checkpoint[name]
                    parameter.requires_grad = False
            del shard_checkpoint
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
        for parameter_name, parameter in model.named_parameters():
            # Same key remapping as the sharded branch above.
            if 'target' in parameter_name:
                parameter.data = checkpoint['target.lm.output_layer.weight']
            elif 'embedding' in parameter_name:
                parameter.data = checkpoint['embedding.word.embedding.weight']
            else:
                parameter.data = checkpoint[parameter_name]
            parameter.requires_grad = False
        del checkpoint
    return model
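
# Usage sketch (paths are illustrative): a directory is loaded shard by shard
# via its pytorch_model.bin.index.json, a single file is loaded in one pass.
#
#   model = load_model(model, "models/llama-13b/")    # sharded checkpoint
#   model = load_model(model, "models/llama-7b.bin")  # single-file checkpoint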