ntt123's picture
add app
587b6c9
"""
Utility functions
"""
import pickle
from pathlib import Path
import pax
import toml
import yaml
from tacotron import Tacotron
def load_tacotron_config(config_file=Path("tacotron.toml")):
    """Read a TOML file and return its ``tacotron`` section.

    Args:
        config_file: path of the TOML configuration file
            (``tacotron.toml`` in the current directory by default).

    Returns:
        The value stored under the top-level ``"tacotron"`` key of the parsed
        document. A ``KeyError`` propagates if that key is absent.
    """
    parsed = toml.load(config_file)
    return parsed["tacotron"]
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
    """Restore a Tacotron training checkpoint from ``path``.

    Args:
        net: model whose ``load_state_dict`` is given the checkpoint's
            ``"model_state_dict"`` entry, or ``None`` to skip it.
        optim: optimizer whose ``load_state_dict`` is given the checkpoint's
            ``"optim_state_dict"`` entry, or ``None`` to skip it.
        path: location of the pickled checkpoint dictionary.

    Returns:
        A ``(step, net, optim)`` tuple. ``net`` and ``optim`` are the values
        returned by their ``load_state_dict`` calls, or ``None`` when the
        corresponding argument was ``None``.

    Note:
        ``pickle.load`` can execute arbitrary code; only load checkpoint
        files from a trusted source.
    """
    with open(path, "rb") as handle:
        checkpoint = pickle.load(handle)
    # load_state_dict's return value is kept (not the original object), so the
    # caller receives whatever object the restore produced.
    restored_net = (
        net if net is None else net.load_state_dict(checkpoint["model_state_dict"])
    )
    restored_optim = (
        optim
        if optim is None
        else optim.load_state_dict(checkpoint["optim_state_dict"])
    )
    return checkpoint["step"], restored_net, restored_optim
def create_tacotron_model(config):
    """Construct a new Tacotron model from the hyper-parameters in ``config``.

    Args:
        config: mapping that must contain every upper-case key listed below
            (e.g. the table returned by ``load_tacotron_config``). A missing
            key raises ``KeyError``.

    Returns:
        A freshly constructed ``Tacotron`` instance (per the original author,
        randomly initialized).
    """
    # (Tacotron keyword argument, config key) pairs, in lookup order.
    arg_to_key = (
        ("mel_dim", "MEL_DIM"),
        ("attn_bias", "ATTN_BIAS"),
        ("rr", "RR"),
        ("max_rr", "MAX_RR"),
        ("mel_min", "MEL_MIN"),
        ("sigmoid_noise", "SIGMOID_NOISE"),
        ("pad_token", "PAD_TOKEN"),
        ("prenet_dim", "PRENET_DIM"),
        ("attn_hidden_dim", "ATTN_HIDDEN_DIM"),
        ("attn_rnn_dim", "ATTN_RNN_DIM"),
        ("rnn_dim", "RNN_DIM"),
        ("postnet_dim", "POSTNET_DIM"),
        ("text_dim", "TEXT_DIM"),
    )
    kwargs = {arg: config[key] for arg, key in arg_to_key}
    return Tacotron(**kwargs)
def load_wavegru_config(config_file):
    """Parse a UTF-8 encoded YAML configuration file.

    Args:
        config_file: path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load``, which only constructs
        plain YAML types (no arbitrary Python objects).
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed
def load_wavegru_ckpt(net, optim, ckpt_file):
    """Restore a WaveGRU training checkpoint stored at ``ckpt_file``.

    Args:
        net: network whose ``load_state_dict`` is given the checkpoint's
            ``"net_state_dict"`` entry, or ``None`` to skip it.
        optim: optimizer whose ``load_state_dict`` is given the checkpoint's
            ``"optim_state_dict"`` entry, or ``None`` to skip it.
        ckpt_file: path of the pickled checkpoint dictionary.

    Returns:
        A ``(step, net, optim)`` tuple. ``net`` and ``optim`` are the values
        returned by their ``load_state_dict`` calls, or ``None`` when the
        corresponding argument was ``None``.

    Note:
        ``pickle.load`` can execute arbitrary code; only load checkpoint
        files from a trusted source.
    """
    with open(ckpt_file, "rb") as handle:
        checkpoint = pickle.load(handle)

    def _restore(target, key):
        # Skip absent targets; otherwise return what load_state_dict produces.
        if target is None:
            return None
        return target.load_state_dict(checkpoint[key])

    restored_net = _restore(net, "net_state_dict")
    restored_optim = _restore(optim, "optim_state_dict")
    return checkpoint["step"], restored_net, restored_optim