# Place this file under egs/wsj0-mix-var/Multi-Decoder-DPRNN and run it to convert best-model.ckpt to pytorch_model.bin.
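# Typical invocation from that directory (an assumed command line, implied by the
# comment above rather than stated in the recipe docs): python convert_checkpoint.py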
import glob
import os
from collections import OrderedDict

import requests
import torch
import yaml

from asteroid import torch_utils
from model import make_model_and_optimizer
exp_dir = "exp/tmp"
# create an exp and checkpoints folder if none exist
os.makedirs(os.path.join(exp_dir, "checkpoints"), exist_ok=True)
# Download a checkpoint if none exists
if len(glob.glob(os.path.join(exp_dir, "checkpoints", "*.ckpt"))) == 0:
    r = requests.get(
        "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
    )
    r.raise_for_status()  # fail early instead of writing an error page to disk
    with open(os.path.join(exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
        handle.write(r.content)
# If no conf.yml exists in the experiment dir, fall back to the default one
conf_path = os.path.join(exp_dir, "conf.yml")
if not os.path.exists(conf_path):
    conf_path = "local/conf.yml"
# Load training config
with open(conf_path) as f:
    train_conf = yaml.safe_load(f)
sample_rate = train_conf["data"]["sample_rate"]
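# Only a few conf.yml sections are read by this script; a minimal sketch of the
# expected layout (keys are the ones accessed here, values purely illustrative):
#   data:
#     sample_rate: 8000
#   filterbank:
#     ...
#   masknet:
#     ...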
best_model_path = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")
model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)
model.eval()
checkpoint = torch.load(best_model_path, map_location="cpu")
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
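# load_state_dict_in is Asteroid's recipe helper for loading a checkpoint's
# state dict into the bare model; as used here it tolerates the extra key
# prefix that Lightning checkpoints carry (see asteroid.torch_utils).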
model_args = {}
model_args.update(train_conf["masknet"])
model_args.update(train_conf["filterbank"])
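# Strip the leading submodule prefix (everything up to and including the first
# ".") from every state-dict key, so that e.g. a Lightning-style key
# "model.encoder.weight" becomes "encoder.weight"; keys without a "." are kept
# unchanged, since str.find returns -1 for them.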
new_state_dict = OrderedDict()
for k, v in checkpoint["state_dict"].items():
    new_k = k[k.find(".") + 1 :]
    new_state_dict[new_k] = v
checkpoint["state_dict"] = new_state_dict
checkpoint["model_name"] = "MultiDecoderDPRNN"
checkpoint["sample_rate"] = sample_rate
checkpoint["model_args"] = model_args
torch.save(checkpoint, "pytorch_model.bin")
print(f"saved checkpoint to pytorch_model.bin")