# Hakka / compress_model.py
# (Hugging Face file-view header: author Naozumi0512, commit "init" 1a79a73, raw, history, blame, 2.42 kB)
from collections import OrderedDict
from text.symbols import symbols
import torch
from tools.log import logger
import utils
from models import SynthesizerTrn
import os
def copyStateDict(state_dict):
    """Return a copy of *state_dict*, stripping a leading ``module`` key component.

    If the first key starts with ``"module"`` (typically the prefix added by
    ``torch.nn.DataParallel`` wrappers), the first dot-separated component is
    removed from every key. Otherwise keys are copied unchanged. Values are not
    copied, and insertion order is preserved.

    Args:
        state_dict: mapping whose keys are dot-separated names.

    Returns:
        A new ``OrderedDict`` with the (possibly) shortened keys. An empty
        input yields an empty ``OrderedDict``.
    """
    renamed = OrderedDict()
    if not state_dict:
        # Nothing to rename; also avoids an IndexError when probing the first key.
        return renamed
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0
    for key, value in state_dict.items():
        # Re-join with "." so dotted names such as "a.b.c" stay dotted.
        renamed[".".join(key.split(".")[start_idx:])] = value
    return renamed
def _filter_model_weights(model_state, ishalf):
    """Return the entries of *model_state* whose names do not contain ``"enc_q"``.

    Args:
        model_state: mapping of parameter name -> tensor, taken from a
            checkpoint's ``"model"`` entry.
        ishalf: when True, floating-point tensors are cast to FP16.

    Returns:
        A new dict in the original key order.
    """
    kept = {}
    for name, tensor in model_state.items():
        # Presumably enc_q is a training-only sub-network (posterior encoder) — TODO confirm.
        if "enc_q" in name:
            continue
        # Cast only floating-point tensors: calling .half() on integer buffers
        # would silently turn them into lossy fp16 values.
        if ishalf and isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
            tensor = tensor.half()
        kept[name] = tensor
    return kept


def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a slimmed copy of a generator checkpoint to *output_model*.

    Loads the checkpoint at *input_model* on CPU and drops every ``"model"``
    entry whose name contains ``"enc_q"``. When *ishalf* is set, it casts
    floating-point weights to FP16. The result is saved together with a
    freshly initialised AdamW optimizer state and ``iteration`` reset to 0.

    Args:
        config: path to the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: path of the checkpoint to read. It must contain a
            ``"model"`` entry.
        ishalf: when True, store floating-point weights as FP16.
        output_model: destination path passed to ``torch.save``.

    Raises:
        KeyError: if the loaded checkpoint has no ``"model"`` entry.
    """
    hps = utils.get_hparams_from_file(config)
    # The network is only built so that a matching AdamW optimizer can be
    # created; its own (random) weights are never saved. Presumably the fresh
    # optimizer state keeps the output loadable by code that expects an
    # "optimizer" entry — NOTE(review): confirm against the training loader.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    # NOTE(security): torch.load unpickles the file; only run on trusted checkpoints.
    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    model_state = _filter_model_weights(checkpoint["model"], ishalf)
    torch.save(
        {
            "model": model_state,
            "iteration": 0,
            "optimizer": optim_g.state_dict(),
            # NOTE(review): hard-coded; may differ from hps.train.learning_rate
            # used to build the optimizer above.
            "learning_rate": 0.0001,
        },
        output_model,
    )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Write a slimmed '_release' copy of a generator checkpoint "
            "(enc_q weights dropped, fresh optimizer state)."
        )
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="configs/config.json",
        help="hyper-parameter config file",
    )
    # Required: the default output name is derived from this path, and a
    # missing value would crash os.path.splitext(None) with a TypeError.
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="checkpoint to compress"
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="output path (default: <input>_release[_half]<ext>)",
    )
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )
    args = parser.parse_args()

    output = args.output
    if output is None:
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    # Log message (Chinese): "Model compressed successfully, output model: <path>".
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")