#!/usr/bin/env python
"""Convert a Megatron-DeepSpeed checkpoint into a Hugging Face Transformers (Kinoe) checkpoint."""

import json
import os
from typing import Dict

import torch

# from deepspeed_checkpoint import DeepSpeedCheckpoint
from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments

# The HF Megatron-GPT2 converter import below was tested to work with
# https://github.com/huggingface/transformers/commit/0af901e83 - if it diverges
# we may consider copying that version here instead.
# from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint
# from transformers import MistralConfig

# NOTE(review): a relative import only resolves when this file is executed as part
# of a package (e.g. `python -m <package>.<module>`); running the file directly as a
# script raises "ImportError: attempted relative import with no known parent package".
from .configuration_kinoe import KinoeConfig

# ----------- temporary fix for relative import issue: DeepSpeedCheckpoint is
# defined below instead of being imported from deepspeed_checkpoint.
ZERO_FILE_PREFIX = 'zero_pp_rank_'
LAYER_FILE_PREFIX = 'layer_'
MP_RANK_FILE_PREFIX = 'mp_rank_'
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
ITERATION_KEY = 'iteration'
# Parameters that _merge_state_dicts takes from the first state dict instead of concatenating.
SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight', 'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight',
]

# Concatenation dim for parameters that are not merged along dim 0.
LAYER_CONCAT_DIM = {
    'self_attention.dense.weight': 1,
    'mlp.dense_4h_to_h.weight': 1,
}


class DeepSpeedCheckpoint:
    """Index the files of a DeepSpeed checkpoint directory by parallel rank.

    Files are classified by name prefix (``zero_pp_rank_``, ``layer_``,
    ``mp_rank_``). The original tensor/pipeline/data parallel degrees are
    inferred from the file counts, and the layer files are grouped so that
    state dicts can be loaded per target TP/PP rank.

    Note: state dicts are read with ``torch.load`` (pickle-based), so only use
    this on trusted checkpoints.
    """

    def __init__(self, dir, tp_degree=None, pp_degree=None, no_pp=False):
        """
        Args:
            dir: checkpoint directory; it is walked recursively.
            tp_degree: target tensor-parallel degree; defaults to the original one.
            pp_degree: target pipeline-parallel degree; defaults to the original one.
            no_pp: when True, the TP degree is taken from the number of
                ``mp_rank_*`` files, the PP degree is 1, and the embedding /
                final-norm maps are not built.
        """
        self.dir = dir
        self.no_pp = no_pp
        self.file_list = self._get_files(dir)
        self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX)
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        if not self.no_pp:
            # Presumably one `layer_01*` file per TP rank, so their count is the TP degree.
            self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01'))
            self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        else:
            self.original_tp_degree = len(self.mp_rank_files)
            self.original_pp_degree = 1
        self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)
        print(f"dp: {self.dp_degree}")
        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        print(f"tp: {self.tp_degree}")
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        print(f"pp: {self.pp_degree}")
        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        if not self.no_pp:
            self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
            self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def show_tp_embedding_map(self):
        """Print the TP rank -> embedding-layer files mapping."""
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        """Print the TP rank -> final-norm-layer files mapping."""
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        """Print the PP rank -> transformer layer keys mapping."""
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')

    def show_transformer_file_map(self):
        """Print the (TP, PP) rank -> transformer layer files mapping."""
        self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')

    def _build_global_state(self):
        """Cache the iteration and training args stored in the first ``mp_rank_*`` file."""
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_iteration(self):
        """Return the stored training iteration (0 if the checkpoint has none)."""
        if ITERATION_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        """Load and merge the embedding-layer state dicts assigned to ``tp_index``."""
        assert tp_index in self.tp_to_embedding_map
        sd_list = [torch.load(fname, map_location=torch.device('cpu'))
                   for fname in self.tp_to_embedding_map[tp_index]]
        return self._merge_state_dicts(sd_list)

    def get_args(self):
        """Return the stored training args, or None if the checkpoint has none."""
        if ARGS_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
        return self.global_state[ARGS_KEY]

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        """Return one merged state dict per transformer layer owned by rank (tp_index, pp_index)."""
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            t_list.append(self._merge_state_dicts(sd_list))
        return t_list

    def get_final_norm_state(self, tp_index: int) -> Dict:
        """Load the final-norm state dict from the first file assigned to ``tp_index``."""
        assert tp_index in self.tp_to_final_norm_map
        return torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))

    def _build_tp_other_layer_map(self, layer_index: int):
        """Map each target TP rank to its slice of the files of layer ``self.layer_keys[layer_index]``.

        ``layer_index`` may be negative (e.g. -1 for the last layer key).
        """
        # The index is applied to layer_keys, so that is the list to bounds-check against.
        assert -len(self.layer_keys) <= layer_index < len(self.layer_keys)
        layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
        return dict(enumerate(layer_file_partitions))

    def _build_pp_transformer_map(self):
        """Split the transformer layer keys (all but the first and last) evenly across PP ranks."""
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        return {
            i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
            for i in range(self.pp_degree)
        }

    def _dump_mapping(self, data_map, map_tag=None):
        """Print ``data_map`` one ``key = value`` pair per line, preceded by an optional tag."""
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        """Map (tp_index, pp_index) to a list with one file partition per transformer layer."""
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = self._get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                file_map.setdefault((tp_index, pp_index), []).append(layer_file_partitions[tp_index])
        return file_map

    def _sanity_check(self):
        """Assert that the file counts divide evenly by the target parallel degrees."""
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        if not self.no_pp:
            assert len(self.layer_keys) > 2
            assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def _get_files_with_prefix(self, all_files, prefix):
        """Return the sorted paths in ``all_files`` whose base name starts with ``prefix``."""
        return sorted(
            file_path for file_path in all_files
            if os.path.basename(file_path).startswith(prefix)
        )

    def validate_files(self):
        """Print an error for every indexed path that is not an existing regular file."""
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} is not existent')

    def _get_files(self, dir):
        """Return the paths of all files under ``dir`` (recursive)."""
        return [os.path.join(root, file) for root, _, files in os.walk(dir) for file in files]

    def _get_layer_keys(self):
        """Return the sorted distinct ``layer_NN`` prefixes of the layer files."""
        key_len = len(LAYER_FILE_PREFIX) + 2
        return sorted({os.path.basename(file_path)[:key_len] for file_path in self.layer_files})

    def _partition_data(self, data_list, num_partitions):
        """Split ``data_list`` into ``num_partitions`` contiguous, equally sized slices."""
        num_elems = len(data_list)
        assert num_elems % num_partitions == 0
        partition_size = num_elems // num_partitions
        return [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]

    def _merge_state_dicts(self, sd_list):
        """Merge per-rank state dicts into one.

        Keys in SEQUENTIAL_LAYERS are taken from the first state dict; all other
        keys are concatenated along LAYER_CONCAT_DIM.get(key, 0).
        """
        merged_sd = {}
        for key in sd_list[0]:
            if key in SEQUENTIAL_LAYERS:
                merged_sd[key] = sd_list[0][key]
            else:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
        return merged_sd

# ------------------------------


def convert_wqkv(
    qkv_w: torch.Tensor,  # e.g. [7680, 5120] with the default head counts and hidden size 5120
    n_heads: int = 40,
    n_heads_kv: int = 10,
    tp_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused query_key_value weight into separate Q, K and V weights.

    The rows of ``qkv_w`` are read as consecutive ``head_dim``-row chunks laid out
    group by group: for each KV group, ``n_heads // n_heads_kv`` query chunks,
    then one key chunk, then one value chunk.

    Args:
        qkv_w: fused weight; split along dim 0, dim 1 is the hidden size.
        n_heads: number of attention (query) heads.
        n_heads_kv: number of key/value heads.
        tp_size: multiplier applied to ``qkv_w.size(1) // n_heads`` to obtain
            the per-head row count.

    Returns:
        ``(wq, wk, wv)``, each the concatenation of its chunks along dim 0.

    Raises:
        ValueError: if the rows of ``qkv_w`` do not form a whole number of groups.
    """
    n_hidden = qkv_w.size(1)
    # Rows per head chunk (128 for hidden 5120 and 40 heads).
    head_dim: int = n_hidden // n_heads * tp_size
    n_qs_per_kv: int = n_heads // n_heads_kv
    group_size = n_qs_per_kv + 2  # query chunks + one key + one value
    n_groups: int = qkv_w.size(0) // head_dim // group_size

    chunks = torch.split(qkv_w, head_dim, dim=0)
    if len(chunks) != n_groups * group_size:
        raise ValueError(
            f"qkv weight with {qkv_w.size(0)} rows does not split into whole groups of "
            f"{group_size} x {head_dim} rows"
        )

    wq, wk, wv = [], [], []
    for group in range(n_groups):
        start = group * group_size
        wq.extend(chunks[start:start + n_qs_per_kv])
        wk.append(chunks[start + n_qs_per_kv])
        wv.append(chunks[start + n_qs_per_kv + 1])

    return torch.cat(wq, dim=0), torch.cat(wk, dim=0), torch.cat(wv, dim=0)


def convert_megatron_checkpoint_custom(args, input_state_dict, config):
    """Convert a merged Megatron-DeepSpeed state dict to HF (Mistral-style) parameter names.

    Args:
        args: command-line arguments; currently unused.
        input_state_dict: state dict with ``model.language_model.{embedding,encoder}``
            and ``model.word_embeddings_for_head``; its optional ``args`` entry
            (training args) selects the dtype.
        config: HF config; ``num_hidden_layers``, ``num_attention_heads`` and
            ``num_key_value_heads`` drive the conversion. ``config.torch_dtype``
            is set in place when training args are present.

    Returns:
        dict mapping HF parameter names to tensors.
    """
    output_state_dict = {}

    # Old versions did not store training args, so only derive the dtype when they exist.
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        if getattr(ds_args, "bf16", False):
            torch_dtype = torch.bfloat16
        elif getattr(ds_args, "fp16", False):
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32
        config.torch_dtype = torch_dtype

    model = input_state_dict["model"]
    lm = model["language_model"]
    embeddings = lm["embedding"]
    encoder = lm["encoder"]

    output_state_dict["model.embed_tokens.weight"] = embeddings["word_embeddings"]['weight']

    encoder_num_layers = config.num_hidden_layers
    for i in range(encoder_num_layers):
        src = f"layers.{i}"
        dst = f"model.layers.{i}"

        output_state_dict[f"{dst}.input_layernorm.weight"] = encoder[f"{src}.input_layernorm.weight"]

        q_proj, k_proj, v_proj = convert_wqkv(
            encoder[f"{src}.self_attention.query_key_value.weight"],
            config.num_attention_heads,
            config.num_key_value_heads,
        )
        output_state_dict[f"{dst}.self_attn.q_proj.weight"] = q_proj
        output_state_dict[f"{dst}.self_attn.k_proj.weight"] = k_proj
        output_state_dict[f"{dst}.self_attn.v_proj.weight"] = v_proj
        output_state_dict[f"{dst}.self_attn.o_proj.weight"] = encoder[f"{src}.self_attention.dense.weight"]

        # dense_h_to_4h is split in half along dim 0: first half -> gate_proj, second -> up_proj.
        dense_h_to_4h_weight = encoder[f"{src}.mlp.dense_h_to_4h.weight"]
        split_size = dense_h_to_4h_weight.size(0) // 2
        gate_proj, up_proj = torch.split(dense_h_to_4h_weight, split_size, dim=0)
        output_state_dict[f"{dst}.mlp.gate_proj.weight"] = gate_proj
        output_state_dict[f"{dst}.mlp.up_proj.weight"] = up_proj
        output_state_dict[f"{dst}.mlp.down_proj.weight"] = encoder[f"{src}.mlp.dense_4h_to_h.weight"]

        output_state_dict[f"{dst}.post_attention_layernorm.weight"] = encoder[f"{src}.post_attention_layernorm.weight"]

    # The final norm is read from the encoder entry indexed one past the last transformer layer.
    output_state_dict["model.norm.weight"] = encoder[f"layers.{encoder_num_layers}.weight"]
    output_state_dict["lm_head.weight"] = model['word_embeddings_for_head']['weight']

    return output_state_dict


def validate_conversion(ds_model, hf_model, dtype):
    """Placeholder for comparing DeepSpeed and HF model outputs (FIXME: not implemented).

    Currently only builds a reproducible random input of shape (1, 2048) in ``dtype``
    and returns None without running either model.
    """
    seed = 1234
    generator = torch.Generator().manual_seed(seed)
    # NOTE(review): causal LMs usually take integer token IDs; a float tensor is kept
    # here to match the original intent - revisit when implementing.
    tensor = torch.rand((1, 2048), dtype=dtype, generator=generator)
    # TODO: run ds_model and hf_model on `tensor` and compare their outputs.
    return


def load_from_hf_checkpoint(cp_path):
    """Load ``cp_path`` with ``AutoModelForCausalLM.from_pretrained(device_map="auto")``."""
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(cp_path, device_map="auto")
    return model


def main():
    """Convert the DeepSpeed checkpoint given on the command line to an HF checkpoint folder.

    Writes ``config.json`` and ``pytorch_model.bin`` into ``args.output_folder``,
    then reloads the folder with ``load_from_hf_checkpoint`` as a smoke test.
    """
    # This first part comes mainly from deepspeed_to_megatron.main.
    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
    input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release)

    # Build the config in HF format.
    config = KinoeConfig(
        vocab_size=55424,
        hidden_size=5120,
        intermediate_size=14336,
        num_hidden_layers=24,
        num_attention_heads=40,
        num_key_value_heads=10,
        hidden_act="silu",
        # NOTE(review): 37268 looks like a possible typo for 32768 - confirm.
        max_position_embeddings=37268,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,  # refer to Megatron-DeepSpeed/megatron/model/rotary_pos_embedding.py
        sliding_window=1024,
        attention_dropout=0.0,
        max_sequence_length=2048,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
    )

    print("Converting to HF Checkpoint")
    output_state_dict = convert_megatron_checkpoint_custom(args, input_state_dict, config)

    basename = args.output_folder
    os.makedirs(basename, exist_ok=True)

    # Store the config to file.
    output_config_file = os.path.join(basename, "config.json")
    output_config = config.to_dict()
    output_config["architectures"] = ["KinoeForCausalLM"]
    # NOTE(review): AutoModelForCausalLM resolves the model class from "model_type";
    # confirm "gpt" matches the type registered for KinoeConfig.
    output_config["model_type"] = "gpt"
    print(f'Saving config to "{output_config_file}"')
    with open(output_config_file, "w", encoding="utf-8") as f:
        json.dump(output_config, f, indent=2)

    # Store the state_dict to file.
    output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)

    # Reload the written checkpoint to check that it is loadable.
    load_from_hf_checkpoint(basename)
    print("Loaded hf model")
    print("Now add tokenizer files and upload to the hub")


if __name__ == "__main__":
    main()