easyGUI / rvc /synthesizer.py
Blane187's picture
Upload 39 files
c3b58fa verified
raw
history blame contribute delete
No virus
1.86 kB
from collections import OrderedDict
import torch
from .layers.synthesizers import SynthesizerTrnMsNSFsid
from .jit import load_inputs, export_jit_model, save_pickle
def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
    """Build an inference-ready ``SynthesizerTrnMsNSFsid`` from a loaded checkpoint.

    Args:
        cpt: Checkpoint mapping (e.g. what ``torch.load`` returns for an RVC
            ``.pth`` file). Reads ``"config"`` (positional constructor
            arguments), ``"weight"`` (state dict), ``"f0"`` (default ``1``)
            and ``"version"`` (default ``"v1"``). ``cpt["config"]`` is
            modified in place: its third-from-last entry is overwritten with
            ``cpt["weight"]["emb_g.weight"].shape[0]``.
        device: Device the model is moved to. Defaults to CPU.

    Returns:
        Tuple ``(net_g, cpt)``: the model cast with ``.float()``, put in eval
        mode, moved to *device* and with ``remove_weight_norm()`` applied;
        and the same checkpoint mapping that was passed in.

    Raises:
        ValueError: If ``cpt["version"]`` is not ``"v1"`` or ``"v2"``.
    """
    # encoder_dim passed to the synthesizer for each supported checkpoint version.
    encoder_dims = {"v1": 256, "v2": 768}
    version = cpt.get("version", "v1")
    # Validate before touching cpt so an unsupported checkpoint is rejected
    # cleanly instead of failing later with an unbound local.
    if version not in encoder_dims:
        raise ValueError(
            f"Unsupported checkpoint version {version!r}; "
            f"expected one of {sorted(encoder_dims)}"
        )
    if_f0 = cpt.get("f0", 1)
    # Size config slot -3 from the saved embedding table so the constructed
    # model matches the stored weights (presumably the speaker count —
    # confirm against SynthesizerTrnMsNSFsid's signature).
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    net_g = SynthesizerTrnMsNSFsid(
        *cpt["config"],
        encoder_dim=encoder_dims[version],
        use_f0=if_f0 == 1,
    )
    # enc_q is dropped before loading (assumed to be training-only — confirm);
    # strict=False tolerates state-dict keys that do not match the pruned model.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g = net_g.float()
    net_g.eval().to(device)
    net_g.remove_weight_norm()
    return net_g, cpt
def load_synthesizer(
    pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")
):
    """Read a checkpoint from *pth_path* and build the synthesizer from it.

    The checkpoint is always deserialized onto the CPU; the resulting model
    is then moved to *device* by ``get_synthesizer``.

    Returns:
        Whatever ``get_synthesizer`` returns: ``(net_g, cpt)``.
    """
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code embedded in it — only load checkpoints from trusted sources.
    checkpoint = torch.load(pth_path, map_location=torch.device("cpu"))
    return get_synthesizer(checkpoint, device)
def _default_jit_save_path(model_path: str, is_half: bool) -> str:
    """Derive the export path: drop a trailing ``.pth`` suffix, append ``.jit``/``.half.jit``."""
    suffix = ".pth"
    # Slice off the exact suffix; str.rstrip(".pth") would strip any trailing
    # run of '.', 'p', 't', 'h' characters (e.g. "depth.pth" -> "de").
    if model_path.endswith(suffix):
        model_path = model_path[: -len(suffix)]
    return model_path + (".half.jit" if is_half else ".jit")


def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export an RVC synthesizer checkpoint as a JIT model and pickle it.

    Args:
        model_path: Path of the ``.pth`` checkpoint to load.
        mode: Export mode forwarded to ``export_jit_model``; when ``"trace"``,
            example inputs are loaded from *inputs_path*.
        inputs_path: Example-input file, only used when ``mode == "trace"``.
        save_path: Output path. Defaults to *model_path* with a trailing
            ``.pth`` removed and ``.jit`` (or ``.half.jit`` when *is_half*)
            appended.
        device: Target device; a bare ``"cuda"`` is normalized to ``cuda:0``.
        is_half: Forwarded to ``load_inputs``/``export_jit_model`` and
            selects the ``.half.jit`` default suffix.

    Returns:
        The checkpoint dict with ``"weight"`` removed and ``"model"`` (from
        the export result) and ``"device"`` added; this dict is also written
        to *save_path* via ``save_pickle``.

    Raises:
        TypeError: If the loaded checkpoint is not a ``dict``.
    """
    if not save_path:
        save_path = _default_jit_save_path(model_path, is_half)
    # A bare "cuda" device carries no index; pin it to the first GPU.
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    model, cpt = load_synthesizer(model_path, device)
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not isinstance(cpt, dict):
        raise TypeError(f"expected checkpoint dict, got {type(cpt).__name__}")
    # Route forward() to infer() so the exported module's forward is the
    # inference path (assumes export_jit_model compiles `forward` — confirm).
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        # Only trace mode loads example inputs; other modes pass None.
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export_jit_model(model, mode, inputs, device, is_half)
    # Replace the raw state dict with the exported model before saving.
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    cpt["device"] = device
    save_pickle(cpt, save_path)
    return cpt