Kinoe-7B / deepspeed_to_transformers_kinoe.py
AoiKazama's picture
Upload Kinoe-7B
6a8d5c3 verified
#!/usr/bin/env python
import os
import torch
import json
# from deepspeed_checkpoint import DeepSpeedCheckpoint
from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments
from typing import Dict
# the import was tested to work with this version
# https://github.com/huggingface/transformers/commit/0af901e83 if it diverges we may consider
# copying that version here instead
# from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint
# from transformers import MistralConfig
from .configuration_kinoe import KinoeConfig
# ----------- temporary workaround for the relative import issue:
# the checkpoint helper class and its constants are defined in this file
# instead of being imported from deepspeed_checkpoint.

# Filename prefixes used to group the files found in a checkpoint directory.
ZERO_FILE_PREFIX = "zero_pp_rank_"
LAYER_FILE_PREFIX = "layer_"
MP_RANK_FILE_PREFIX = "mp_rank_"

# Positions of the embedding and final-norm entries in the sorted layer keys.
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1

# Keys read from the first mp_rank state dict.
ARGS_KEY = "args"
ITERATION_KEY = "iteration"

# State-dict keys that DeepSpeedCheckpoint._merge_state_dicts copies from the
# first shard instead of concatenating across shards.
SEQUENTIAL_LAYERS = [
    "input_layernorm.weight",
    "input_layernorm.bias",
    "self_attention.dense.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "mlp.dense_4h_to_h.bias",
    "position_embeddings.weight",
]

# Concatenation dimension for specific keys when merging shards; keys that are
# not listed here are concatenated along dimension 0.
LAYER_CONCAT_DIM = {
    "self_attention.dense.weight": 1,
    "mlp.dense_4h_to_h.weight": 1,
}
class DeepSpeedCheckpoint(object):
    """Index the files of a DeepSpeed checkpoint directory by parallel rank.

    All files under ``dir`` are collected recursively and grouped by filename
    prefix (``ZERO_FILE_PREFIX``, ``LAYER_FILE_PREFIX``, ``MP_RANK_FILE_PREFIX``).
    From the file counts the original tensor-/pipeline-/data-parallel degrees
    are derived, and maps from (tp, pp) ranks to layer files are built so that
    per-rank state dicts can be loaded and merged on CPU.

    Args:
        dir: Root directory of the checkpoint.
        tp_degree: Target tensor-parallel degree; defaults to the original one.
        pp_degree: Target pipeline-parallel degree; defaults to the original one.
        no_pp: When true, no per-layer files are expected; the tensor-parallel
            degree is taken from the number of mp_rank files and the pipeline
            degree is 1.
    """

    def __init__(self, dir, tp_degree=None, pp_degree=None, no_pp=False):
        self.dir = dir
        self.no_pp = no_pp
        self.file_list = self._get_files(dir)
        self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX)
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)

        if not self.no_pp:
            # One 'layer_01' file exists per tensor-parallel rank.
            self.original_tp_degree = len(
                self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01'))
            self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        else:
            self.original_tp_degree = len(self.mp_rank_files)
            self.original_pp_degree = 1

        self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)
        print(f"dp: {self.dp_degree}")
        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        print(f"tp: {self.tp_degree}")
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        print(f"pp: {self.pp_degree}")

        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        if not self.no_pp:
            self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
            self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    # ------------------------------------------------------------------
    # Debug helpers
    # ------------------------------------------------------------------
    def show_tp_embedding_map(self):
        """Print the tp-rank -> embedding-files map."""
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        """Print the tp-rank -> final-norm-files map."""
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        """Print the pp-rank -> transformer-layer-keys map."""
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')

    def show_transformer_file_map(self):
        """Print the (tp, pp) -> transformer-files map."""
        self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')

    # ------------------------------------------------------------------
    # Global state (iteration / training args)
    # ------------------------------------------------------------------
    @staticmethod
    def _load_cpu(path):
        """Load a checkpoint file onto the CPU.

        NOTE(security): torch.load unpickles the file, so only checkpoints from
        a trusted source should be converted.
        """
        return torch.load(path, map_location=torch.device('cpu'))

    def _build_global_state(self):
        """Cache the iteration and training args from the first mp_rank file."""
        sd = self._load_cpu(self.mp_rank_files[0])
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def _get_global_entry(self, key, default):
        """Return a cached global-state entry, loading it on first access."""
        if key not in self.global_state:
            sd = self._load_cpu(self.mp_rank_files[0])
            self.global_state[key] = sd.get(key, default)
        return self.global_state[key]

    def get_iteration(self):
        """Return the stored iteration (0 when the checkpoint has none)."""
        return self._get_global_entry(ITERATION_KEY, 0)

    def get_args(self):
        """Return the stored training args (None when the checkpoint has none)."""
        return self._get_global_entry(ARGS_KEY, None)

    # ------------------------------------------------------------------
    # State-dict accessors
    # ------------------------------------------------------------------
    def get_embedding_state(self, tp_index: int) -> Dict:
        """Load and merge the embedding shards assigned to ``tp_index``."""
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [self._load_cpu(fname) for fname in self.tp_to_embedding_map[tp_index]]
        return self._merge_state_dicts(sd_list)

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        """Return one merged state dict per transformer layer of rank (tp, pp)."""
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [self._load_cpu(fname) for fname in fname_list]
            t_list.append(self._merge_state_dicts(sd_list))
        return t_list

    def get_final_norm_state(self, tp_index: int) -> Dict:
        """Load the final-norm state; only the first shard for ``tp_index`` is read."""
        assert tp_index in self.tp_to_final_norm_map.keys()
        return self._load_cpu(self.tp_to_final_norm_map[tp_index][0])

    # ------------------------------------------------------------------
    # Map construction
    # ------------------------------------------------------------------
    def _build_tp_other_layer_map(self, layer_index: int):
        """Map each target tp rank to the files of the layer key at ``layer_index``."""
        assert layer_index < len(self.layer_files)
        layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
        return dict(enumerate(layer_file_partitions))

    def _build_pp_transformer_map(self):
        """Map each pp rank to its contiguous slice of transformer layer keys."""
        # The first and last layer keys are the embedding and the final norm.
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        return {
            i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
            for i in range(0, self.pp_degree)
        }

    def _dump_mapping(self, data_map, map_tag=None):
        """Print every ``key = value`` pair of ``data_map``, with an optional header."""
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        """Map each (tp, pp) rank to a list of per-layer file partitions."""
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = self._get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                file_map.setdefault((tp_index, pp_index), []).append(
                    layer_file_partitions[tp_index])
        return file_map

    def _sanity_check(self):
        """Assert that the file counts are compatible with the target degrees."""
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        if not self.no_pp:
            # Layer-file checks only make sense when per-layer files exist.
            assert len(self.layer_keys) > 2
            assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    # ------------------------------------------------------------------
    # File discovery
    # ------------------------------------------------------------------
    def _get_files_with_prefix(self, all_files, prefix):
        """Return the sorted paths whose basename starts with ``prefix``."""
        return sorted(
            file_path for file_path in all_files
            if os.path.split(file_path)[1].startswith(prefix)
        )

    def validate_files(self):
        """Print an error for every indexed path that is not a regular file."""
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} is not existent')

    def _get_files(self, dir):
        """Recursively collect every file path under ``dir``."""
        return [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(dir)
            for file in files
        ]

    def _get_layer_keys(self):
        """Return the sorted unique layer keys (prefix plus two characters)."""
        key_len = len(LAYER_FILE_PREFIX) + 2
        key_set = {os.path.split(file_path)[1][:key_len] for file_path in self.layer_files}
        return sorted(key_set)

    # ------------------------------------------------------------------
    # Partitioning / merging
    # ------------------------------------------------------------------
    def _partition_data(self, data_list, num_partitions):
        """Split ``data_list`` into ``num_partitions`` equally sized chunks.

        Raises:
            AssertionError: If the list length is not divisible by
                ``num_partitions``.
            ValueError: If ``data_list`` is empty, which previously surfaced as
                an opaque ``range()`` error.
        """
        num_elems = len(data_list)
        assert num_elems % num_partitions == 0
        if num_elems == 0:
            raise ValueError(
                f'Cannot split an empty file list into {num_partitions} partitions; '
                'check that the checkpoint directory contains the expected files')
        partition_size = num_elems // num_partitions
        return [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]

    def _merge_state_dicts(self, sd_list):
        """Merge shard state dicts into one.

        Keys listed in ``SEQUENTIAL_LAYERS`` are copied from the first shard;
        all other keys are concatenated along ``LAYER_CONCAT_DIM`` (default 0).
        """
        merged_sd = {}
        for key in sd_list[0].keys():
            if key not in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]
        return merged_sd
# ------------------------------
def convert_wqkv(
    qkv_w: torch.Tensor,  # [7680, 5120]
    n_heads: int = 40,
    n_heads_kv: int = 10,
    tp_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV weight into separate query, key and value weights.

    The rows of ``qkv_w`` are cut into chunks of ``head_dim`` rows. The chunks
    are consumed group by group: each group contributes
    ``n_heads // n_heads_kv`` query chunks, then one key chunk, then one
    value chunk.

    Args:
        qkv_w (torch.Tensor): Fused weight, split along dimension 0.
        n_heads (int, optional): Number of attention heads.
        n_heads_kv (int, optional): Number of key/value heads.
        tp_size (int, optional): Multiplier applied to the per-head row count.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(wq, wk, wv)``, each
        the concatenation of its chunks along dimension 0.

    Raises:
        AssertionError: If chunks remain after all groups were consumed.
    """
    # e.g. 5120 // 40 * 1 = 128 rows per head
    head_dim: int = qkv_w.size(1) // n_heads * tp_size
    # e.g. 40 // 10 = 4 query heads per key/value head
    q_per_kv: int = n_heads // n_heads_kv
    chunks_per_group = q_per_kv + 2
    # e.g. 7680 // 128 // 6 = 10 groups
    group_count: int = qkv_w.size(0) // head_dim // chunks_per_group

    chunks = torch.split(qkv_w, head_dim, dim=0)

    q_parts, k_parts, v_parts = [], [], []
    for group in range(group_count):
        start = group * chunks_per_group
        k_at = start + q_per_kv
        q_parts.extend(chunks[start:k_at])
        k_parts.append(chunks[k_at])
        v_parts.append(chunks[k_at + 1])

    # Every chunk must have been assigned to exactly one of q / k / v.
    assert len(chunks) - group_count * chunks_per_group == 0

    return (
        torch.concat(q_parts, dim=0),
        torch.concat(k_parts, dim=0),
        torch.concat(v_parts, dim=0),
    )
def convert_megatron_checkpoint_custom(args, input_state_dict, config):
    """Convert a Megatron-style state dict to HF-style parameter names (Mistral layout).

    Args:
        args: Not used by this function; kept for interface compatibility.
        input_state_dict: Dict with a ``"model"`` entry holding
            ``language_model.embedding``, ``language_model.encoder`` and
            ``word_embeddings_for_head``, and optionally an ``"args"`` entry
            with the training arguments.
        config: Config object providing ``num_hidden_layers``,
            ``num_attention_heads`` and ``num_key_value_heads``. Its
            ``torch_dtype`` is set from the training args when they exist.

    Returns:
        dict: Mapping from HF parameter names to tensors.
    """
    output_state_dict = {}

    # Old checkpoint versions did not store the training args, so this may be
    # None; in that case the config's dtype is left untouched.
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        # Derive the torch dtype from the training precision flags.
        torch_dtype = torch.float32
        if getattr(ds_args, "bf16", False):
            torch_dtype = torch.bfloat16
        elif getattr(ds_args, "fp16", False):
            torch_dtype = torch.float16
        # Adjust the config.
        config.torch_dtype = torch_dtype

    model = input_state_dict["model"]
    lm = model["language_model"]
    embeddings = lm["embedding"]
    encoder = lm["encoder"]

    # model.embed_tokens.weight
    output_state_dict["model.embed_tokens.weight"] = embeddings["word_embeddings"]['weight']

    # layers
    encoder_num_layers = config.num_hidden_layers
    for i in range(encoder_num_layers):
        src = f"layers.{i}"
        dst = f"model.layers.{i}"

        output_state_dict[f"{dst}.input_layernorm.weight"] = encoder[f"{src}.input_layernorm.weight"]

        # The fused QKV weight (e.g. size (7680, 5120)) is split into q/k/v.
        qkv_weight = encoder[f"{src}.self_attention.query_key_value.weight"]
        q_proj, k_proj, v_proj = convert_wqkv(
            qkv_weight, config.num_attention_heads, config.num_key_value_heads)
        output_state_dict[f"{dst}.self_attn.q_proj.weight"] = q_proj
        output_state_dict[f"{dst}.self_attn.k_proj.weight"] = k_proj
        output_state_dict[f"{dst}.self_attn.v_proj.weight"] = v_proj
        output_state_dict[f"{dst}.self_attn.o_proj.weight"] = encoder[f"{src}.self_attention.dense.weight"]

        # dense_h_to_4h is split in half along dim 0: first half -> gate_proj,
        # second half -> up_proj.
        dense_h_to_4h_weight = encoder[f"{src}.mlp.dense_h_to_4h.weight"]
        split_size = dense_h_to_4h_weight.size(0) // 2
        gate_proj, up_proj = torch.split(dense_h_to_4h_weight, split_size, dim=0)
        output_state_dict[f"{dst}.mlp.gate_proj.weight"] = gate_proj
        output_state_dict[f"{dst}.mlp.up_proj.weight"] = up_proj
        output_state_dict[f"{dst}.mlp.down_proj.weight"] = encoder[f"{src}.mlp.dense_4h_to_h.weight"]

        output_state_dict[f"{dst}.post_attention_layernorm.weight"] = encoder[f"{src}.post_attention_layernorm.weight"]

    # model.norm.weight: stored under the index right after the last layer.
    output_state_dict["model.norm.weight"] = encoder[f"layers.{encoder_num_layers}.weight"]

    # lm_head.weight
    output_state_dict["lm_head.weight"] = model['word_embeddings_for_head']['weight']

    return output_state_dict
# FIXME: validation is still a stub; no comparison is performed yet.
def validate_conversion(ds_model, hf_model, dtype):
    """Placeholder for comparing the DeepSpeed and HF models on a random input.

    Currently only builds a reproducible random tensor of shape (1, 2048) in
    ``dtype``; neither model is invoked.

    Args:
        ds_model: Source model (unused for now).
        hf_model: Converted model (unused for now).
        dtype: torch dtype of the probe tensor.

    Returns:
        None
    """
    seed = 1234
    # `torch.random` is a module, not a function; draw from a locally seeded
    # generator so the global RNG state is left alone, then cast so that any
    # requested dtype is accepted.
    generator = torch.Generator().manual_seed(seed)
    tensor = torch.rand((1, 2048), generator=generator).to(dtype)
    # TODO: run inference with each model on `tensor` and compare the outputs.
    return
def load_from_hf_checkpoint(cp_path):
    """Load a causal-LM from ``cp_path`` with ``device_map="auto"`` and return it."""
    # Imported lazily so that transformers is only required when this runs.
    import transformers

    return transformers.AutoModelForCausalLM.from_pretrained(cp_path, device_map="auto")
def main():
    """Convert a DeepSpeed checkpoint into a HF Transformers checkpoint folder.

    Steps: parse the CLI arguments, build the rank-(0, 0) Megatron state dict
    from the DeepSpeed checkpoint, convert it to HF parameter names, write
    ``config.json`` and ``pytorch_model.bin`` into the output folder, and load
    the result back once as a smoke test.
    """
    # This first part comes mainly from deepspeed_to_megatron.main
    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
    input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release)

    # Get config with HF format.
    config = KinoeConfig(
        vocab_size=55424,
        hidden_size=5120,
        # Was misspelled `intermeditate_size`, so the value never reached the
        # config's `intermediate_size` field.
        intermediate_size=14336,
        num_hidden_layers=24,
        num_attention_heads=40,
        num_key_value_heads=10,
        hidden_act="silu",
        # NOTE(review): 37268 may be a transposition of 32768 - confirm.
        max_position_embeddings=37268,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,  # refer to Megatron-DeepSpeed/megatron/model/rotary_pos_embedding.py
        sliding_window=1024,
        attention_dropout=0.0,
        max_sequence_length=2048,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
    )

    # Convert.
    print("Converting to HF Checkpoint")
    output_state_dict = convert_megatron_checkpoint_custom(args, input_state_dict, config)

    basename = args.output_folder
    os.makedirs(basename, exist_ok=True)

    # Store the config to file.
    output_config_file = os.path.join(basename, "config.json")
    output_config = config.to_dict()
    output_config["architectures"] = ["KinoeForCausalLM"]
    # NOTE(review): "gpt" overrides whatever model_type KinoeConfig reports -
    # confirm this is what the loading code expects.
    output_config["model_type"] = "gpt"
    print(f'Saving config to "{output_config_file}"')
    with open(output_config_file, "w", encoding="utf-8") as f:
        json.dump(output_config, f)

    # Store the state_dict to file.
    output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)

    # Load the saved checkpoint back as a HF model to verify that it is usable.
    load_from_hf_checkpoint(basename)
    print("Loaded hf model")

    print("Now add tokenizer files and upload to the hub")


if __name__ == "__main__":
    main()