# NOTE: scraped page header ("Spaces: Runtime error") -- not part of the script.
import argparse
from math import ceil
from pathlib import Path
from typing import Dict, Mapping, Optional, Sequence

import torch
# Rewrite rules for checkpoint keys, applied in insertion order.
# A value of None means "drop every key that contains this substring";
# any other value is a plain substring replacement.
WEIGHT_NAME_MAP = {
    "model.encodec_embeddings": None,
    "encodec_embeddings": "embed_encodec",
    "encodec_mlm_head": "mcm_heads",
}

# Row counts of the padded embedding tables are rounded up to a multiple
# of this value.
PAD_MULTIPLE = 64


def _remap_key(
    key: str, name_map: Mapping[str, Optional[str]]
) -> Optional[str]:
    """Return the converted name for *key*, or None if it must be dropped.

    Rules are applied in the iteration order of *name_map*, so a drop rule
    is tested against the key as rewritten by the rules that precede it.
    """
    new_key = key
    for old, new in name_map.items():
        if new is None:
            if old in new_key:
                return None
            continue
        new_key = new_key.replace(old, new)
    return new_key


def _pad_rows(
    weight: torch.Tensor, multiple: int = PAD_MULTIPLE
) -> torch.Tensor:
    """Pad a 2-D *weight* along dim 0 up to the next multiple of *multiple*.

    The original rows are copied unchanged; the added rows are drawn from
    a standard normal distribution. The result keeps the dtype of *weight*
    (sampling happens in the default dtype and is cast afterwards, so a
    half-precision table is not silently turned into float32). If the row
    count is already aligned, *weight* itself is returned.
    """
    rows = weight.shape[0]
    padded_rows = ceil(rows / multiple) * multiple
    if padded_rows == rows:
        return weight
    padded = torch.normal(0, 1, (padded_rows, weight.shape[1]))
    padded = padded.to(dtype=weight.dtype)
    padded[:rows] = weight
    return padded


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Convert ``<ckpt_path>/pytorch_model.bin`` in place.

    Keys are renamed or dropped according to ``WEIGHT_NAME_MAP``, every
    tensor whose new key contains ``model.encoder.embed_encodec`` is
    row-padded, the original file is renamed to ``pytorch_model.bin.bak``
    and the converted state dict is saved under the original file name.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the script would fail later on Path(None).
    parser.add_argument("--ckpt_path", "-c", type=str, required=True)
    args = parser.parse_args(argv)

    ckpt_path = Path(args.ckpt_path)
    weight_file = ckpt_path / "pytorch_model.bin"
    # NOTE(security): torch.load unpickles the file; only run this script
    # on checkpoints from a trusted source.
    state_dict = torch.load(weight_file, map_location="cpu")

    new_state_dict: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        new_key = _remap_key(key, WEIGHT_NAME_MAP)
        if new_key is None:
            continue
        if "model.encoder.embed_encodec" in new_key:
            value = _pad_rows(value)
        new_state_dict[new_key] = value

    # Keep the untouched checkpoint next to the converted one.
    weight_file.rename(weight_file.with_suffix(".bin.bak"))
    torch.save(new_state_dict, weight_file)


if __name__ == "__main__":
    main()