import json
import os
import sys
from argparse import Namespace

import torch


def load_hyperparam(default_args):
    """
    Load arguments from argparse and the config file.
    Priority: default options < config file < command-line args.
    """
    with open(default_args.config_path, mode="r", encoding="utf-8") as f:
        config_args_dict = json.load(f)

    default_args_dict = vars(default_args)

    # Only flags passed explicitly on the command line (ignoring --local_rank,
    # which is injected by distributed launchers) override the config file.
    command_line_args_dict = {k: default_args_dict[k] for k in [
        a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
    ]}
    default_args_dict.update(config_args_dict)
    default_args_dict.update(command_line_args_dict)
    args = Namespace(**default_args_dict)
    return args
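

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how load_hyperparam
# is typically driven. Only --config_path mirrors the code above; every other
# flag name below is a hypothetical placeholder for demonstration.
def _example_load_hyperparam():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, default="config.json")
    parser.add_argument("--batch_size", type=int, default=32)  # hypothetical flag
    # Precedence: values in config.json override these defaults, and any flag
    # passed explicitly on the command line overrides the config file.
    return load_hyperparam(parser.parse_args())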


def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
    """
    Load a PyTorch checkpoint into `model_to_load`, renaming old-format keys
    to the new format and gathering DeepSpeed ZeRO-3 partitioned parameters
    layer by layer.
    """
    # Copy the state_dict so _load_from_state_dict can modify it.
    state_dict = torch.load(model_path, map_location="cpu")
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()

    # Convert the old key names to the new format.
    state_dict['target.lm.weight'] = state_dict['target.lm.output_layer.weight']
    del state_dict['target.lm.output_layer.weight']
    state_dict['embedding.embedding.weight'] = state_dict['embedding.word.embedding.weight']
    del state_dict['embedding.word.embedding.weight']

    if metadata is not None:
        metadata['embedding.embedding'] = metadata['embedding.word.embedding']
        metadata['target.lm'] = metadata['target.lm.output_layer']
        if metadata.get('embedding.dropout', None) is not None:
            del metadata['embedding.dropout']
        del metadata['embedding.word']
        del metadata['embedding.word.embedding']
        del metadata['target.lm.output_layer']
        del metadata['target.lm.softmax']
        del metadata['target.lm.criterion']
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's
    # descendants, so we need to apply the function recursively.
    def load(module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of the module and its children all start with `prefix`,
        # so we can exit early if none of them are in this state_dict.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            import deepspeed
            # In sharded models, each shard has only part of the full state_dict,
            # so only gather parameters that are in the current state_dict.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            if len(params_to_gather) > 0:
                # Because ZeRO-3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer,
                # loads them from the state dict, and then re-partitions them.
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it can be garbage-collected earlier. Since
    # `state_dict` is a copy of the loaded checkpoint, this is safe.
    del state_dict

    return model_to_load
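

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): under ZeRO-3 every
# rank must call _load_state_dict_into_model so the GatheredParameters context
# can collectively unpartition each layer, while only rank 0 actually copies
# weights from the checkpoint. `model` is assumed to have been constructed
# under DeepSpeed elsewhere in the project.
def _example_zero3_load(model, model_path):
    return _load_state_dict_into_model(model, model_path, start_prefix="")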


def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    """
    Recursively re-wrap the weights of `bnb.nn.Linear8bitLt` modules as frozen
    `Int8Params`, skipping any module whose name is in `modules_to_not_convert`.
    """
    import bitsandbytes as bnb

    modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)

        if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
            # Check that the current key is not in `modules_to_not_convert`.
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                model._modules[name].weight = bnb.nn.Int8Params(
                    module.weight.data,
                    requires_grad=False,
                    has_fp16_weights=False
                )
                # Force requires_grad to False to avoid unexpected errors.
                model._modules[name].requires_grad_(False)
        # Remove the last key for the next sibling / recursion level.
        current_key_name.pop(-1)
    return model
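

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the converter above
# only touches layers that are already bnb.nn.Linear8bitLt, so a model's
# nn.Linear layers must be swapped for Linear8bitLt beforehand. The toy module
# below is an assumption made purely for demonstration.
def _example_int8_conversion():
    import torch.nn as nn
    import bitsandbytes as bnb

    toy = nn.Sequential()
    toy.add_module("proj", bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False))
    toy.add_module("lm", nn.Linear(16, 8))  # "lm" is excluded from conversion
    # Re-wrap the eligible int8 layer's weight as frozen Int8Params.
    return convert_normal_parameter_to_int8(toy, threshold=6.0)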


def load_model(model, model_path):
    """
    Copy pretrained weights into `model` from either a single checkpoint file
    or a sharded checkpoint directory, freezing every loaded parameter.
    """
    if os.path.isdir(model_path):
        # Sharded checkpoint: the index file maps parameter names to shard files.
        index_filename = os.path.join(model_path, 'pytorch_model.bin.index.json')
        with open(index_filename, "r") as f:
            index = json.loads(f.read())
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
        for shard_file in shard_filenames:
            shard_checkpoint = torch.load(shard_file, map_location='cpu')
            for name, parameter in model.named_parameters():
                if shard_checkpoint.get(name, None) is not None:
                    if 'target' in name:
                        parameter.data = shard_checkpoint['target.lm.output_layer.weight']
                    elif 'embedding' in name:
                        parameter.data = shard_checkpoint['embedding.word.embedding.weight']
                    else:
                        parameter.data = shard_checkpoint[name]
                    parameter.requires_grad = False
            del shard_checkpoint
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
        for parameter_name, parameter in model.named_parameters():
            if 'target' in parameter_name:
                parameter.data = checkpoint['target.lm.output_layer.weight']
            elif 'embedding' in parameter_name:
                parameter.data = checkpoint['embedding.word.embedding.weight']
            else:
                parameter.data = checkpoint[parameter_name]
            parameter.requires_grad = False
        del checkpoint
    return model
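

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the original module): load_model
# accepts either a single checkpoint file or a directory holding a sharded
# checkpoint with pytorch_model.bin.index.json. `model` is assumed to be the
# project's LM module built elsewhere, and the default path is a placeholder.
def _example_load_checkpoint(model, model_path="path/to/checkpoint_or_dir"):
    model = load_model(model, model_path)
    # Every parameter copied by load_model is frozen (requires_grad=False),
    # so the model is ready for inference.
    model.eval()
    return model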