JunzheJosephZhu commited on
Commit
759ffb2
1 Parent(s): 22015d8

Upload 2 files

Browse files
Files changed (2) hide show
  1. convert_checkpoint.py +45 -0
  2. pytorch_model.bin +3 -0
convert_checkpoint.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import glob
4
+ import requests
5
+ from model import make_model_and_optimizer
6
+ import torch
7
+ from asteroid import torch_utils
8
+ from collections import OrderedDict
9
+
10
+ exp_dir = "exp/tmp"
11
+ # create an exp and checkpoints folder if none exist
12
+ os.makedirs(os.path.join(exp_dir, "checkpoints"), exist_ok=True)
13
+ # Download a checkpoint if none exists
14
+ if len(glob.glob(os.path.join(exp_dir, "checkpoints", "*.ckpt"))) == 0:
15
+ r = requests.get(
16
+ "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
17
+ )
18
+ with open(os.path.join(exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
19
+ handle.write(r.content)
20
+ # if conf doesn't exist, copy default one
21
+ conf_path = os.path.join(exp_dir, "conf.yml")
22
+ if not os.path.exists(conf_path):
23
+ conf_path = "local/conf.yml"
24
+ # Load training config
25
+ with open(conf_path) as f:
26
+ train_conf = yaml.safe_load(f)
27
+ sample_rate = train_conf["data"]["sample_rate"]
28
+ best_model_path = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")
29
+ model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)
30
+ model.eval()
31
+ checkpoint = torch.load(best_model_path, map_location="cpu")
32
+ model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
33
+ model_args = {}
34
+ model_args.update(train_conf["masknet"])
35
+ model_args.update(train_conf["filterbank"])
36
+ new_state_dict = OrderedDict()
37
+ for k, v in checkpoint["state_dict"].items():
38
+ new_k = k[k.find(".") + 1 :]
39
+ new_state_dict[new_k] = v
40
+ checkpoint["state_dict"] = new_state_dict
41
+ checkpoint["model_name"] = "MultiDecoderDPRNN"
42
+ checkpoint["sample_rate"] = sample_rate
43
+ checkpoint["model_args"] = model_args
44
+ torch.save(checkpoint, "pytorch_model.bin")
45
+ print(f"saved checkpoint to pytorch_model.bin")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad367460c6641c6cd9441ada692c1e28fd9431525f52071f03851b928f659efb
3
+ size 62269301