|
|
|
import argparse
import json
import os
import shutil
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file
|
|
|
# --- Command-line interface -------------------------------------------------
# The script is a one-shot converter: it reads an original DBRX checkpoint
# directory and writes a converted copy suitable for quantization tooling.
parser = argparse.ArgumentParser(description="Convert original dbrx model into quantizable model")

parser.add_argument("--model-dir", type=str, required=True, help="directory to the original dbrx model")
parser.add_argument("--output-dir", type=str, required=True, help="directory for the converted dbrx model")
args = parser.parse_args()

# Resolve both directories as Path objects; create the destination if needed
# (parents=True + exist_ok=True makes this idempotent on re-runs).
model_dir = Path(args.model_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
# Model geometry (hard-coded for DBRX; values match the original model config).
NUM_EXPERTS = 16          # MoE experts per FFN layer (used to unstack w1/v1/w2)
HIDDEN_SIZE = 6144        # model width; also the query rows in the fused Wqkv
HEAD_DIM = 128            # per-head dimension of attention
NUM_KV_HEAD = 8           # key/value heads; k/v rows = NUM_KV_HEAD * HEAD_DIM
FFN_HIDDEN_SIZE = 10752   # per-expert FFN intermediate size
|
|
|
def change_tensor_attn(tensor):
    """Split a fused Wqkv weight into separate q/k/v weight tensors.

    The fused tensor is split along dim 0 into a query part (HIDDEN_SIZE
    rows) and key/value parts (NUM_KV_HEAD * HEAD_DIM rows each); each part
    is made contiguous so it can be saved as an independent tensor.
    """
    kv_rows = NUM_KV_HEAD * HEAD_DIM
    q, k, v = tensor.split([HIDDEN_SIZE, kv_rows, kv_rows])
    return [q.contiguous(), k.contiguous(), v.contiguous()]
|
|
|
def change_attn(tensors):
    """Replace each fused '<prefix>.Wqkv.weight' entry in *tensors* with three
    entries '<prefix>.q_proj.weight' / '.k_proj.weight' / '.v_proj.weight'.

    Mutates *tensors* in place and returns it. Entries that are not fused
    attention weights are left untouched.
    """
    for key in list(tensors.keys()):
        # Match the exact suffix. The previous substring test ('Wqkv' in key)
        # combined with a silently no-op removesuffix() could mangle any key
        # that contains 'Wqkv' without ending in '.Wqkv.weight'.
        if not key.endswith('.Wqkv.weight'):
            continue
        prefix = key.removesuffix('.Wqkv.weight')
        fused = tensors.pop(key)
        for proj, weight in zip(['q_proj', 'k_proj', 'v_proj'], change_tensor_attn(fused)):
            tensors[f'{prefix}.{proj}.weight'] = weight

    return tensors
|
|
|
def change_tensor_mlp(tensor, reverse=False):
    """Unstack a combined expert weight into NUM_EXPERTS per-expert tensors.

    The flat (NUM_EXPERTS * FFN_HIDDEN_SIZE, HIDDEN_SIZE) weight is viewed as
    one (FFN_HIDDEN_SIZE, HIDDEN_SIZE) slice per expert. With reverse=True
    (used for w2 weights) each slice is additionally transposed before being
    made contiguous.
    """
    per_expert = tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)
    if reverse:
        return [expert.t().contiguous() for expert in per_expert]
    return [expert.contiguous() for expert in per_expert]
|
|
|
def change_mlp(tensors):
    """Explode stacked expert weights ('<prefix>.w1' / '.v1' / '.w2') into one
    entry per expert: '<prefix>.<expert_idx>.<name>.weight'.

    Mutates *tensors* in place and returns it. w2 slices are transposed by
    change_tensor_mlp; other entries are left untouched.
    """
    for key in list(tensors.keys()):
        # Match the exact suffix. The previous substring test
        # (any('w1' in key ...)) could fire on unrelated keys that merely
        # contain 'w1'/'v1'/'w2' anywhere, yielding a bogus dtype from rsplit.
        if not key.endswith(('.w1', '.v1', '.w2')):
            continue
        prefix, name = key.rsplit('.', 1)
        stacked = tensors.pop(key)
        # w2 is stored transposed relative to w1/v1, so reverse it on unstack.
        for idx, expert_weight in enumerate(change_tensor_mlp(stacked, name == 'w2')):
            tensors[f'{prefix}.{idx}.{name}.weight'] = expert_weight

    return tensors
|
|
|
# Rewrite every weight shard: load all tensors into memory, split fused
# attention weights and stacked expert weights, then save the shard under the
# same filename in the output directory.
for file in sorted(list(model_dir.glob('*.safetensors'))):
    print(file)
    tensors = {}
    with safe_open(file, 'pt') as f:
        # Preserve shard metadata (e.g. the safetensors 'format' field) so the
        # converted shard carries the same header information.
        metadata = f.metadata()
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    tensors = change_attn(tensors)
    tensors = change_mlp(tensors)
    # NOTE(review): metadata is passed positionally as save_file's third
    # argument (metadata=...); None metadata from a shard would be forwarded
    # as-is.
    save_file(tensors, (output_dir / file.name).as_posix(), metadata)
|
|
|
# Rewrite the safetensors index so every renamed weight maps to the same shard
# file that held the fused/stacked original.
with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

for k in list(weight_map['weight_map']):
    # Exact suffix matches, mirroring change_mlp/change_attn: the previous
    # substring tests could fire on unrelated keys containing 'w1'/'v1'/'w2'
    # or 'Wqkv'.
    if k.endswith(('.w1', '.v1', '.w2')):
        # Stacked expert weight -> one index entry per expert.
        prefix, dtype = k.rsplit('.', 1)
        value = weight_map['weight_map'].pop(k)
        for i in range(NUM_EXPERTS):
            weight_map['weight_map'][f'{prefix}.{i}.{dtype}.weight'] = value
    elif k.endswith('.Wqkv.weight'):
        # Fused attention weight -> separate q/k/v entries.
        prefix = k.removesuffix('.Wqkv.weight')
        value = weight_map['weight_map'].pop(k)
        for proj in ['q_proj', 'k_proj', 'v_proj']:
            weight_map['weight_map'][f'{prefix}.{proj}.weight'] = value

# Sort entries so the emitted index is deterministic and diff-friendly.
weight_map['weight_map'] = dict(sorted(weight_map['weight_map'].items()))

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)
|
|
|
|
|
# Copy every remaining file (tokenizer, config, etc.) to the output directory.
# Weight shards and the index were already written above, so skip them.
# Fixes a NameError in the original: 'os' was used without being imported;
# pathlib (already in use throughout the script) replaces it here.
for src in model_dir.iterdir():
    if src.suffix == ".safetensors" or src.name == "model.safetensors.index.json":
        continue
    if src.is_file():
        # copy2 preserves file metadata (timestamps, permissions).
        shutil.copy2(src, output_dir / src.name)
|
|