|
|
|
import os |
|
import yaml |
|
import glob |
|
import requests |
|
from model import make_model_and_optimizer |
|
import torch |
|
from asteroid import torch_utils |
|
from collections import OrderedDict |
|
|
|
exp_dir = "exp/tmp"

os.makedirs(os.path.join(exp_dir, "checkpoints"), exist_ok=True)

# Fetch the pretrained checkpoint from the Hugging Face Hub unless a
# checkpoint already exists locally in exp/tmp/checkpoints/.
if len(glob.glob(os.path.join(exp_dir, "checkpoints", "*.ckpt"))) == 0:
    r = requests.get(
        "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt",
        timeout=60,  # fail fast instead of hanging forever on a dead connection
    )
    # Without this check an HTTP error page (404/5xx HTML) would be silently
    # written to disk and later fail to deserialize in torch.load.
    r.raise_for_status()
    with open(os.path.join(exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
        handle.write(r.content)
|
|
|
# Locate the training configuration: prefer the copy saved alongside the
# experiment, falling back to the repo-local default config.
conf_path = os.path.join(exp_dir, "conf.yml")

if not os.path.exists(conf_path):

    conf_path = "local/conf.yml"



with open(conf_path) as f:

    train_conf = yaml.safe_load(f)

sample_rate = train_conf["data"]["sample_rate"]

best_model_path = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")

# Rebuild the model architecture from the training config (optimizer is
# discarded) so the downloaded weights can be loaded into it.
model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)

model.eval()

# map_location="cpu" lets this run on machines without a GPU even if the
# checkpoint was saved from CUDA tensors.
checkpoint = torch.load(best_model_path, map_location="cpu")

# Load the checkpoint weights into the rebuilt model — presumably as a
# sanity check that config and weights match, since only `checkpoint`
# (not `model`) is serialized below; NOTE(review): confirm intent.
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
|
# Repackage the Lightning checkpoint into an Asteroid-style
# pytorch_model.bin suitable for publishing (e.g. to the Hugging Face Hub).
output_path = "pytorch_model.bin"

# Collect the constructor arguments needed to rebuild the model later.
model_args = {}
model_args.update(train_conf["masknet"])
model_args.update(train_conf["filterbank"])

# Strip the leading wrapper prefix (everything up to the first ".") that
# the training wrapper adds to every state-dict key. Keys containing no
# dot are left unchanged: str.find returns -1, so the slice starts at 0.
new_state_dict = OrderedDict()
for k, v in checkpoint["state_dict"].items():
    new_k = k[k.find(".") + 1 :]
    new_state_dict[new_k] = v
checkpoint["state_dict"] = new_state_dict

checkpoint["model_name"] = "MultiDecoderDPRNN"
checkpoint["sample_rate"] = sample_rate
checkpoint["model_args"] = model_args

torch.save(checkpoint, output_path)
print(f"saved checkpoint to {output_path}")