File size: 1,859 Bytes
587b6c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Utility functions
"""
import pickle
from pathlib import Path

import pax
import toml
import yaml

from tacotron import Tacotron


def load_tacotron_config(config_file=Path("tacotron.toml")):
    """
    Load the project configurations
    """
    return toml.load(config_file)["tacotron"]


def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
    """
    load checkpoint from disk
    """
    with open(path, "rb") as f:
        dic = pickle.load(f)
    if net is not None:
        net = net.load_state_dict(dic["model_state_dict"])
    if optim is not None:
        optim = optim.load_state_dict(dic["optim_state_dict"])
    return dic["step"], net, optim


def create_tacotron_model(config):
    """
    return a random initialized Tacotron model
    """
    return Tacotron(
        mel_dim=config["MEL_DIM"],
        attn_bias=config["ATTN_BIAS"],
        rr=config["RR"],
        max_rr=config["MAX_RR"],
        mel_min=config["MEL_MIN"],
        sigmoid_noise=config["SIGMOID_NOISE"],
        pad_token=config["PAD_TOKEN"],
        prenet_dim=config["PRENET_DIM"],
        attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
        attn_rnn_dim=config["ATTN_RNN_DIM"],
        rnn_dim=config["RNN_DIM"],
        postnet_dim=config["POSTNET_DIM"],
        text_dim=config["TEXT_DIM"],
    )


def load_wavegru_config(config_file):
    """
    Load project configurations
    """
    with open(config_file, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def load_wavegru_ckpt(net, optim, ckpt_file):
    """
    load training checkpoint from file
    """
    with open(ckpt_file, "rb") as f:
        dic = pickle.load(f)

    if net is not None:
        net = net.load_state_dict(dic["net_state_dict"])
    if optim is not None:
        optim = optim.load_state_dict(dic["optim_state_dict"])
    return dic["step"], net, optim