JunzheJosephZhu
commited on
Commit
•
759ffb2
1
Parent(s):
22015d8
Upload 2 files
Browse files- convert_checkpoint.py +45 -0
- pytorch_model.bin +3 -0
convert_checkpoint.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import glob
|
4 |
+
import requests
|
5 |
+
from model import make_model_and_optimizer
|
6 |
+
import torch
|
7 |
+
from asteroid import torch_utils
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
exp_dir = "exp/tmp"
|
11 |
+
# create an exp and checkpoints folder if none exist
|
12 |
+
os.makedirs(os.path.join(exp_dir, "checkpoints"), exist_ok=True)
|
13 |
+
# Download a checkpoint if none exists
|
14 |
+
if len(glob.glob(os.path.join(exp_dir, "checkpoints", "*.ckpt"))) == 0:
|
15 |
+
r = requests.get(
|
16 |
+
"https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
|
17 |
+
)
|
18 |
+
with open(os.path.join(exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
|
19 |
+
handle.write(r.content)
|
20 |
+
# if conf doesn't exist, copy default one
|
21 |
+
conf_path = os.path.join(exp_dir, "conf.yml")
|
22 |
+
if not os.path.exists(conf_path):
|
23 |
+
conf_path = "local/conf.yml"
|
24 |
+
# Load training config
|
25 |
+
with open(conf_path) as f:
|
26 |
+
train_conf = yaml.safe_load(f)
|
27 |
+
sample_rate = train_conf["data"]["sample_rate"]
|
28 |
+
best_model_path = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")
|
29 |
+
model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)
|
30 |
+
model.eval()
|
31 |
+
checkpoint = torch.load(best_model_path, map_location="cpu")
|
32 |
+
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
|
33 |
+
model_args = {}
|
34 |
+
model_args.update(train_conf["masknet"])
|
35 |
+
model_args.update(train_conf["filterbank"])
|
36 |
+
new_state_dict = OrderedDict()
|
37 |
+
for k, v in checkpoint["state_dict"].items():
|
38 |
+
new_k = k[k.find(".") + 1 :]
|
39 |
+
new_state_dict[new_k] = v
|
40 |
+
checkpoint["state_dict"] = new_state_dict
|
41 |
+
checkpoint["model_name"] = "MultiDecoderDPRNN"
|
42 |
+
checkpoint["sample_rate"] = sample_rate
|
43 |
+
checkpoint["model_args"] = model_args
|
44 |
+
torch.save(checkpoint, "pytorch_model.bin")
|
45 |
+
print(f"saved checkpoint to pytorch_model.bin")
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ad367460c6641c6cd9441ada692c1e28fd9431525f52071f03851b928f659efb
|
3 |
+
size 62269301
|