|
|
|
|
|
import json
import os
from typing import Dict

import torch

from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments
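# `_create_rank_checkpoint` and `parse_arguments` come from the Megatron-DeepSpeed
# `deepspeed_to_megatron.py` conversion helper, which is expected to be importable
# from this directory (or on PYTHONPATH).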
|
|
|
|
|
|
|
|
|
|
|
|
|
from .configuration_kinoe import KinoeConfig |
|
|
|
|
|
|
|
ZERO_FILE_PREFIX = 'zero_pp_rank_' |
|
LAYER_FILE_PREFIX = 'layer_' |
|
MP_RANK_FILE_PREFIX = 'mp_rank_' |
|
EMBEDDING_LAYER_INDEX = 0 |
|
FINAL_LAYER_NORM_INDEX = -1 |
|
ARGS_KEY = 'args' |
|
ITERATION_KEY = 'iteration' |
|
SEQUENTIAL_LAYERS = [ |
|
'input_layernorm.weight', 'input_layernorm.bias', |
|
'self_attention.dense.bias', |
|
'post_attention_layernorm.weight', 'post_attention_layernorm.bias', |
|
'mlp.dense_4h_to_h.bias', |
|
'position_embeddings.weight' |
|
] |
|
|
|
LAYER_CONCAT_DIM = { |
|
'self_attention.dense.weight': 1, |
|
'mlp.dense_4h_to_h.weight': 1 |
|
} |
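# How tensor-parallel shards are merged back together in _merge_state_dicts:
# keys listed in SEQUENTIAL_LAYERS are replicated across TP ranks, so the copy
# from the first shard is kept as-is; every other weight is concatenated across
# shards along dim 0 unless LAYER_CONCAT_DIM specifies a different axis (as it
# does for the row-parallel attention-output and MLP down projections).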
|
|
|
class DeepSpeedCheckpoint(object): |
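    """Index over a Megatron-DeepSpeed checkpoint directory.

    The directory is expected to contain 'layer_*' files (per-layer weights, one
    file per tensor-parallel rank), 'mp_rank_*' files (model-parallel state holding
    'args' and 'iteration'), and 'zero_pp_rank_*' files (ZeRO optimizer shards);
    the original TP/PP/DP degrees are inferred from how many of each exist.
    """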
|
def __init__(self, dir, tp_degree=None, pp_degree=None, no_pp=False): |
|
self.dir = dir |
|
self.no_pp = no_pp |
|
self.file_list = self._get_files(dir) |
|
self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) |
|
self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) |
|
self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) |
|
self.layer_keys = self._get_layer_keys() |
|
self.layer_count = len(self.layer_keys) |
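        # Infer the parallelism the checkpoint was saved with from the file layout:
        # with pipeline layers saved separately, the number of files for layer '01'
        # (one per TP rank) gives the original TP degree and the mp_rank file count
        # divided by TP gives the original PP degree; without pipeline parallelism,
        # each mp_rank file is one TP shard.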
|
if not self.no_pp: |
|
self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01')) |
|
self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree |
|
else: |
|
self.original_tp_degree = len(self.mp_rank_files) |
|
self.original_pp_degree = 1 |
|
self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree) |
|
print(f"dp: {self.dp_degree}") |
|
|
|
self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree |
|
print(f"tp: {self.tp_degree}") |
|
|
|
self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree |
|
print(f"pp: {self.pp_degree}") |
|
|
|
self.global_state = {} |
|
|
|
self._sanity_check() |
|
self.pp_to_transformer_map = self._build_pp_transformer_map() |
|
self.transformer_file_map = self._build_transformer_file_map() |
|
if not self.no_pp: |
|
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) |
|
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) |
|
self._build_global_state() |
|
|
|
|
|
|
|
def show_tp_embedding_map(self): |
|
self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') |
|
|
|
def show_tp_final_norm_map(self): |
|
self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') |
|
|
|
    def show_pp_transformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_transformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')
|
|
|
def _build_global_state(self): |
|
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) |
|
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) |
|
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) |
|
|
|
def get_iteration(self): |
|
        if ITERATION_KEY not in self.global_state:
|
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) |
|
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) |
|
|
|
return self.global_state[ITERATION_KEY] |
|
|
|
def get_embedding_state(self, tp_index: int) -> Dict: |
|
assert tp_index in self.tp_to_embedding_map.keys() |
|
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] |
|
sd = self._merge_state_dicts(sd_list) |
|
return sd |
|
|
|
def get_args(self): |
|
        if ARGS_KEY not in self.global_state:
|
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) |
|
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) |
|
|
|
return self.global_state[ARGS_KEY] |
|
|
|
|
|
def get_transformer_state(self, tp_index: int, pp_index: int) -> list: |
|
assert tp_index < self.tp_degree |
|
assert pp_index < self.pp_degree |
|
t_list = [] |
|
for fname_list in self.transformer_file_map[(tp_index, pp_index)]: |
|
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] |
|
sd = self._merge_state_dicts(sd_list) |
|
t_list.append(sd) |
|
return t_list |
|
|
|
def get_final_norm_state(self, tp_index:int) -> Dict: |
|
assert tp_index in self.tp_to_final_norm_map.keys() |
|
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu')) |
|
return sd |
|
|
|
def _build_tp_other_layer_map(self, layer_index:int): |
|
assert layer_index < len(self.layer_files) |
|
layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) |
|
layer_file_partitions = self._partition_data(layer_files, self.tp_degree) |
|
data_map = {i:flist for i, flist in enumerate(layer_file_partitions)} |
|
return data_map |
|
|
|
def _build_pp_transformer_map(self): |
|
data_map = {} |
|
transformer_layers = self.layer_keys[1:-1] |
|
layers_per_pp = len(transformer_layers) // self.pp_degree |
|
data_map = {i:transformer_layers[i*layers_per_pp:(i+1)*layers_per_pp] for i in range(0, self.pp_degree)} |
|
return data_map |
|
|
|
def _dump_mapping(self, data_map, map_tag = None): |
|
if map_tag is not None: |
|
print(f'Dump mapping: {map_tag}') |
|
for k, v in data_map.items(): |
|
print(f'{k} = {v}') |
|
|
|
def _build_transformer_file_map(self): |
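        """Map (tp_index, pp_index) to the groups of layer files that rank must load.

        Each value is a list with one entry per transformer layer assigned to that
        pipeline stage; each entry is the list of checkpoint files whose tensors are
        merged (across original TP shards) to rebuild that layer.
        """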
|
transformer_layer_keys = self.layer_keys[1:-1] |
|
file_map = {} |
|
layers_per_pp = len(transformer_layer_keys) // self.pp_degree |
|
for key_index, layer_key in enumerate(transformer_layer_keys): |
|
pp_index = key_index // layers_per_pp |
|
layer_files = self._get_files_with_prefix(self.layer_files, layer_key) |
|
layer_file_partitions = self._partition_data(layer_files, self.tp_degree) |
|
for tp_index in range(self.tp_degree): |
|
map_key = (tp_index, pp_index) |
|
                if map_key not in file_map:
|
file_map[map_key] = [] |
|
file_map[map_key].append(layer_file_partitions[tp_index]) |
|
|
|
return file_map |
|
|
|
def _sanity_check(self): |
|
assert len(self.mp_rank_files) % self.tp_degree == 0 |
|
assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 |
|
if not self.no_pp: |
|
assert len(self.layer_keys) > 2 |
|
assert (len(self.layer_keys) - 2) % self.pp_degree == 0 |
|
|
|
def _get_files_with_prefix(self, all_files, prefix): |
|
file_list = [] |
|
for file_path in all_files: |
|
_, fname = os.path.split(file_path) |
|
if fname.startswith(prefix): |
|
file_list.append(file_path) |
|
|
|
return sorted(file_list) |
|
|
|
def validate_files(self): |
|
for file in self.file_list: |
|
if not os.path.isfile(file): |
|
                print(f'Error: {file} does not exist')
|
|
|
def _get_files(self, dir): |
|
file_list = [] |
|
for root, dirs, files in os.walk(dir): |
|
for file in files: |
|
file_list.append(os.path.join(root, file)) |
|
return file_list |
|
|
|
def _get_layer_keys(self): |
|
key_set = set() |
|
key_len = len(LAYER_FILE_PREFIX) + 2 |
|
for file_path in self.layer_files: |
|
_, fname = os.path.split(file_path) |
|
key_set.add(fname[:key_len]) |
|
return sorted(list(key_set)) |
|
|
|
def _partition_data(self, data_list, num_partitions): |
|
num_elems = len(data_list) |
|
assert num_elems % num_partitions == 0 |
|
partition_size = num_elems // num_partitions |
|
partitions_list = [data_list[i:i+partition_size] for i in range(0, num_elems, partition_size)] |
|
return partitions_list |
|
|
|
def _merge_state_dicts(self, sd_list): |
|
merged_sd = {} |
|
for key in sd_list[0].keys(): |
|
            if key not in SEQUENTIAL_LAYERS:
|
cat_dim = LAYER_CONCAT_DIM.get(key, 0) |
|
merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) |
|
else: |
|
merged_sd[key] = sd_list[0][key] |
|
return merged_sd |
|
|
|
|
|
|
|
def convert_wqkv( |
|
qkv_w: torch.Tensor, |
|
n_heads: int = 40, |
|
n_heads_kv: int = 10, |
|
tp_size: int = 1, |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Args: |
|
qkv_w (torch.Tensor): |
|
layer_idx (int, optional): |
|
n_heads (int, optional): |
|
n_heads_kv (int, optional): |
|
|
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
n_hidden = qkv_w.size(1) |
|
|
|
hidden_dim: int = n_hidden // n_heads * tp_size |
|
|
|
|
|
|
|
n_qs_per_kv: int = n_heads // n_heads_kv |
|
|
|
|
|
n_groups: int = qkv_w.size(0) // hidden_dim // (n_qs_per_kv + 2) |
|
|
|
qkv_w: list[torch.Tensor] = list(torch.split(qkv_w, hidden_dim, dim=0)) |
|
|
|
    wq, wk, wv = [], [], []
    for _ in range(n_groups):
        # Each group is stored as [q_1, ..., q_{n_qs_per_kv}, k, v]; pop the
        # slices in that order.
        for _ in range(n_qs_per_kv):
            wq.append(qkv_w[0])
            del qkv_w[0]
        wk.append(qkv_w[0])
        del qkv_w[0]
        wv.append(qkv_w[0])
        del qkv_w[0]
    assert len(qkv_w) == 0
|
|
|
wq = torch.concat(wq, dim=0) |
|
wk = torch.concat(wk, dim=0) |
|
wv = torch.concat(wv, dim=0) |
|
return wq, wk, wv |
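
# Shape check for convert_wqkv (illustrative, using the defaults above and the
# hidden size configured in main()): with hidden_size=5120, n_heads=40,
# n_heads_kv=10 and tp_size=1, the head size is 5120 // 40 = 128 and each of the
# 10 KV groups holds 4 query slices plus one key and one value slice of 128 rows,
# so qkv_w has 10 * 6 * 128 = 7680 rows and the results are wq: [5120, 5120],
# wk / wv: [1280, 5120].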
|
|
|
|
|
def convert_megatron_checkpoint_custom(args, input_state_dict, config): |
|
"""Custom function that converts megatron checkpoints to hf compatible ones for Mistral.""" |
|
output_state_dict = {} |
|
|
|
|
|
ds_args = input_state_dict.get("args", None) |
|
|
|
|
|
    # Determine the dtype the model was trained with; fall back to fp32 when the
    # checkpoint does not carry training args.
    torch_dtype = torch.float32
    if ds_args is not None:
        if ds_args.bf16:
            torch_dtype = torch.bfloat16
        elif ds_args.fp16:
            torch_dtype = torch.float16
        config.torch_dtype = torch_dtype
|
|
|
model = input_state_dict["model"] |
|
lm = model["language_model"] |
|
embeddings = lm["embedding"] |
|
encoder = lm["encoder"] |
|
|
|
|
|
output_state_dict["model.embed_tokens.weight"] = embeddings["word_embeddings"]['weight'] |
|
|
|
|
|
encoder_num_layers = config.num_hidden_layers |
|
for i in range(encoder_num_layers): |
|
|
|
output_state_dict[f"model.layers.{i}.input_layernorm.weight"] = encoder[f"layers.{i}.input_layernorm.weight"] |
|
|
|
|
|
qkv_weight = encoder[f"layers.{i}.self_attention.query_key_value.weight"] |
|
|
|
q_proj, k_proj, v_proj = convert_wqkv(qkv_weight, config.num_attention_heads, config.num_key_value_heads) |
|
|
|
output_state_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_proj |
|
|
|
output_state_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_proj |
|
|
|
output_state_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_proj |
|
|
|
output_state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = encoder[f"layers.{i}.self_attention.dense.weight"] |
|
|
|
dense_h_to_4h_weight = encoder[f"layers.{i}.mlp.dense_h_to_4h.weight"] |
|
split_size = dense_h_to_4h_weight.size(0) // 2 |
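        # Assumed layout: 'dense_h_to_4h' stacks the SwiGLU gate and up projections
        # along dim 0, so splitting it in half yields gate_proj first and up_proj second.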
|
|
|
output_state_dict[f"model.layers.{i}.mlp.gate_proj.weight"], output_state_dict[f"model.layers.{i}.mlp.up_proj.weight"] = torch.split(dense_h_to_4h_weight, split_size, dim=0) |
|
|
|
output_state_dict[f"model.layers.{i}.mlp.down_proj.weight"] = encoder[f"layers.{i}.mlp.dense_4h_to_h.weight"] |
|
|
|
|
|
output_state_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = encoder[f"layers.{i}.post_attention_layernorm.weight"] |
|
|
|
|
|
output_state_dict["model.norm.weight"] = encoder[f"layers.{encoder_num_layers}.weight"] |
|
|
|
|
|
output_state_dict["lm_head.weight"] = model['word_embeddings_for_head']['weight'] |
|
|
|
return output_state_dict |
|
|
|
|
|
def validate_conversion(ds_model, hf_model, dtype):
    # Placeholder: run the original and converted models on the same random input
    # and compare their outputs. The comparison itself is not implemented yet.
    seed = 1234
    torch.manual_seed(seed)
    tensor = torch.randn((1, 2048), dtype=dtype)
    # TODO: feed both models equivalent inputs derived from `tensor` and assert
    # that their outputs match within tolerance.
    return
|
|
|
def load_from_hf_checkpoint(cp_path): |
|
from transformers import AutoModelForCausalLM |
|
model = AutoModelForCausalLM.from_pretrained(cp_path, device_map="auto") |
|
return model |
|
|
|
|
|
def main(): |
|
|
|
|
|
args = parse_arguments() |
|
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}') |
|
|
|
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp) |
|
iteration = ds_checkpoint.get_iteration() |
|
input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release) |
|
|
|
|
|
config = KinoeConfig( |
|
vocab_size=55424, |
|
hidden_size=5120, |
|
        intermediate_size=14336,
|
num_hidden_layers=24, |
|
num_attention_heads=40, |
|
num_key_value_heads=10, |
|
hidden_act="silu", |
|
max_position_embeddings=37268, |
|
rms_norm_eps=1e-6, |
|
use_cache=True, |
|
tie_word_embeddings=False, |
|
rope_theta=10000, |
|
sliding_window=1024, |
|
        attention_dropout=0.0,
|
max_sequence_length=2048, |
|
pad_token_id=None, |
|
bos_token_id=1, |
|
eos_token_id=2, |
|
) |
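    # NOTE: the hyperparameters above are hard-coded for this particular model;
    # adjust them when converting a checkpoint with a different architecture.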
|
|
|
|
|
print("Converting to HF Checkpoint") |
|
output_state_dict = convert_megatron_checkpoint_custom(args, input_state_dict, config) |
|
|
|
basename = args.output_folder |
|
os.makedirs(basename, exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
output_config_file = os.path.join(basename, "config.json") |
|
output_config = config.to_dict() |
|
output_config["architectures"] = ["KinoeForCausalLM"] |
|
output_config["model_type"] = "gpt" |
|
print(f'Saving config to "{output_config_file}"') |
|
with open(output_config_file, "w") as f: |
|
json.dump(output_config, f) |
|
|
|
|
|
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") |
|
print(f'Saving checkpoint to "{output_checkpoint_file}"') |
|
torch.save(output_state_dict, output_checkpoint_file) |
|
|
|
|
|
model = load_from_hf_checkpoint(basename) |
|
print("Loaded hf model") |
|
|
|
print("Now add tokenizer files and upload to the hub") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|