import argparse
import json
import os
import shutil
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file
# Command-line interface: both directories are mandatory.
parser = argparse.ArgumentParser(
    description="Convert original dbrx model into quantizable model",
)
for flag, help_text in (
    ("--model-dir", "directory to the original dbrx model"),
    ("--output-dir", "directory for the converted dbrx model"),
):
    parser.add_argument(flag, type=str, required=True, help=help_text)
args = parser.parse_args()

# Work with Path objects from here on; create the destination (and any
# missing parents) up front so later writes cannot fail on a missing folder.
model_dir, output_dir = Path(args.model_dir), Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Architecture constants of the DBRX checkpoint being converted. The values
# follow the published DBRX configuration (d_model=6144, n_heads=48,
# kv_n_heads=8, moe_num_experts=16, ffn_hidden_size=10752).
HIDDEN_SIZE = 6144  # model width; also the number of query rows in a fused Wqkv
HEAD_DIM = 128  # width of one attention head (HIDDEN_SIZE / 48 heads)
NUM_KV_HEAD = 8  # key/value heads; k and v each take NUM_KV_HEAD * HEAD_DIM rows
NUM_EXPERTS = 16  # experts stacked inside each fused w1 / v1 / w2 tensor
FFN_HIDDEN_SIZE = 10752  # feed-forward width of a single expert
def change_tensor_attn(tensor):
    """Split a fused Wqkv tensor into its three projection pieces.

    The tensor is cut along its first dimension into chunks of
    ``HIDDEN_SIZE``, ``NUM_KV_HEAD * HEAD_DIM`` and ``NUM_KV_HEAD * HEAD_DIM``
    rows.

    Args:
        tensor: The fused tensor to split.

    Returns:
        A list with the three chunks, each made contiguous, in split order.
    """
    kv_rows = NUM_KV_HEAD * HEAD_DIM
    pieces = []
    for chunk in tensor.split([HIDDEN_SIZE, kv_rows, kv_rows]):
        # split() returns views into the fused tensor; materialize each one.
        pieces.append(chunk.contiguous())
    return pieces
def change_attn(tensors):
    """Replace every fused ``*.Wqkv.weight`` entry with q/k/v entries.

    For each key ending in ``.Wqkv.weight`` the tensor is removed, split with
    ``change_tensor_attn`` and stored again as ``<prefix>.q_proj.weight``,
    ``<prefix>.k_proj.weight`` and ``<prefix>.v_proj.weight``. All other
    entries are left untouched.

    Args:
        tensors: Mapping from tensor name to tensor. It is modified in place.

    Returns:
        The same mapping object, for call chaining.
    """
    fused_suffix = '.Wqkv.weight'
    # Iterate over a snapshot of the keys: the dict is mutated inside the loop.
    for key in list(tensors):
        # Match the exact suffix that is stripped below. A plain substring
        # test would also accept keys such as "...Wqkv.bias", for which
        # removesuffix() is a no-op and the new names would be malformed.
        if not key.endswith(fused_suffix):
            continue
        prefix = key.removesuffix(fused_suffix)
        pieces = change_tensor_attn(tensors.pop(key))
        for proj_name, piece in zip(['q_proj', 'k_proj', 'v_proj'], pieces):
            tensors[f'{prefix}.{proj_name}.weight'] = piece
    return tensors
def change_tensor_mlp(tensor, reverse=False):
    """Cut a fused expert tensor into one tensor per expert.

    The input is reshaped to ``(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)``
    and then unstacked along the first axis.

    Args:
        tensor: The fused tensor holding all experts.
        reverse: When true, each per-expert slice is transposed before being
            made contiguous.

    Returns:
        A list of ``NUM_EXPERTS`` contiguous tensors.
    """
    stacked = tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)
    per_expert = []
    for expert_slice in stacked:
        if reverse:
            expert_slice = expert_slice.t()
        per_expert.append(expert_slice.contiguous())
    return per_expert
def change_mlp(tensors):
    """Replace every fused expert tensor with one entry per expert.

    A key whose final dot-separated component is ``w1``, ``v1`` or ``w2`` is
    removed, split with ``change_tensor_mlp`` and stored again as
    ``<prefix>.<expert index>.<name>.weight``. ``w2`` slices are transposed
    (``reverse=True``); ``w1`` and ``v1`` slices are not. All other entries
    are left untouched.

    Args:
        tensors: Mapping from tensor name to tensor. It is modified in place.

    Returns:
        The same mapping object, for call chaining.
    """
    expert_param_names = ('w1', 'v1', 'w2')
    # Iterate over a snapshot of the keys: the dict is mutated inside the loop.
    for key in list(tensors):
        # Decide on the final name component only. A substring test would
        # also pick up unrelated keys that merely contain "w1"/"v1"/"w2".
        prefix, _, param_name = key.rpartition('.')
        if not prefix or param_name not in expert_param_names:
            continue
        slices = change_tensor_mlp(tensors.pop(key), reverse=param_name == 'w2')
        for expert_idx, expert_tensor in enumerate(slices):
            tensors[f'{prefix}.{expert_idx}.{param_name}.weight'] = expert_tensor
    return tensors
# Convert the weight shards one at a time, in sorted path order, and write
# each result to the output directory under the shard's original file name.
for file in sorted(model_dir.glob('*.safetensors')):
    print(file)
    with safe_open(file, 'pt') as f:
        # Carry the shard's metadata over to the converted file unchanged.
        metadata = f.metadata()
        tensors = {name: f.get_tensor(name) for name in f.keys()}
    # Attention tensors are split first, then the expert tensors.
    tensors = change_mlp(change_attn(tensors))
    target = (output_dir / file.name).as_posix()
    save_file(tensors, target, metadata)
# Rewrite the shard index so that it lists the split tensor names instead of
# the fused ones.
with open(model_dir / 'model.safetensors.index.json', encoding='utf-8') as f:
    weight_map = json.load(f)

# Every new name points at the shard file that held the fused tensor.
name_to_shard = weight_map['weight_map']
# Iterate over a snapshot of the keys: the dict is mutated inside the loop.
for k in list(name_to_shard):
    prefix, _, param_name = k.rpartition('.')
    if prefix and param_name in ('w1', 'v1', 'w2'):
        # Fused expert tensor: one entry per expert. Matching on the final
        # name component avoids touching keys that merely contain "w1" etc.
        value = name_to_shard.pop(k)
        for i in range(NUM_EXPERTS):
            name_to_shard[f'{prefix}.{i}.{param_name}.weight'] = value
    elif k.endswith('.Wqkv.weight'):
        # Fused attention projection: one entry each for q, k and v. The
        # test matches exactly the suffix that is stripped.
        prefix = k.removesuffix('.Wqkv.weight')
        value = name_to_shard.pop(k)
        for proj_name in ['q_proj', 'k_proj', 'v_proj']:
            name_to_shard[f'{prefix}.{proj_name}.weight'] = value

# Keep the index sorted by tensor name so the output is deterministic.
weight_map['weight_map'] = dict(sorted(name_to_shard.items()))
with open(output_dir / 'model.safetensors.index.json', 'w', encoding='utf-8') as f:
    json.dump(weight_map, f, indent=4)
# Copy every other regular file from the model directory to the output
# directory. Weight shards and the shard index are skipped because converted
# versions of them are produced by this script; sub-directories are skipped
# as well.
for src in model_dir.iterdir():
    filename = src.name
    if filename.endswith(".safetensors") or filename == "model.safetensors.index.json":
        continue
    if src.is_file():
        # copy2 also preserves file metadata such as modification times.
        shutil.copy2(src, output_dir / filename)