easyGUI / rvc /jit /jit.py
Blane187's picture
Upload 39 files
c3b58fa verified
raw
history blame contribute delete
No virus
2.15 kB
import pickle
from io import BytesIO
from collections import OrderedDict
import os
import torch
def load_pickle(path: str):
    """Return the object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code, so ``path`` must point to a
    file that comes from a trusted source.
    """
    with open(path, "rb") as stream:
        obj = pickle.load(stream)
    return obj
def save_pickle(ckpt: dict, save_path: str):
    """Pickle ``ckpt`` into the file at ``save_path``.

    The file is opened in binary write mode, so any existing content at
    ``save_path`` is replaced.
    """
    with open(save_path, "wb") as stream:
        pickle.dump(ckpt, stream)
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
    """Load saved example inputs and prepare them for ``device``.

    The file is deserialized onto the CPU first; every entry is then moved to
    ``device``. Entries are cast float32 -> float16 when ``is_half`` is true,
    and float16 -> float32 otherwise; other dtypes are left as they are.

    Assumes the file holds a mapping whose values are tensors (each value
    must support ``.to`` and ``.dtype``).

    Returns the loaded mapping, updated in place.
    """
    tensors = torch.load(path, map_location=torch.device("cpu"))
    for name in tensors.keys():
        moved = tensors[name].to(device)
        if is_half:
            if moved.dtype == torch.float32:
                moved = moved.half()
        elif moved.dtype == torch.float16:
            moved = moved.float()
        tensors[name] = moved
    return tensors
def export_jit_model(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile ``model`` to TorchScript and return it as a serialized checkpoint.

    Args:
        model: Module to export. It is cast in place to half or float
            precision (per ``is_half``) and switched to eval mode.
        mode: ``"trace"`` to use ``torch.jit.trace`` with ``inputs`` as
            example keyword inputs, or ``"script"`` to use
            ``torch.jit.script``.
        inputs: Example keyword inputs for tracing; required when
            ``mode == "trace"`` and ignored for ``"script"``.
        device: Device the compiled module is moved to before saving.
        is_half: Export in float16 when true, float32 otherwise.

    Returns:
        An ``OrderedDict`` with keys ``"model"`` (the bytes written by
        ``torch.jit.save``), ``"is_half"`` and ``"device"`` (``str(device)``).

    Raises:
        ValueError: If ``mode`` is not ``"trace"``/``"script"``, or if
            ``mode == "trace"`` and ``inputs`` is ``None``.
    """
    # Validate up front so a bad call fails before the model is touched and
    # with a clear message (previously an unknown mode left ``model_jit``
    # unbound, and the ``assert`` vanished under ``python -O``).
    if mode not in ("trace", "script"):
        raise ValueError(
            f"unsupported export mode {mode!r}; expected 'trace' or 'script'"
        )
    if mode == "trace" and inputs is None:
        raise ValueError("mode='trace' requires example `inputs`")

    model = model.half() if is_half else model.float()
    model.eval()

    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)

    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    # Drop the compiled module once its bytes are captured in the buffer.
    del model_jit

    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    # Record the target device: get_jit_model compares ckpt["device"] with
    # str(device) to decide whether a cached export can be reused.
    cpt["device"] = str(device)
    return cpt
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
    """Return a JIT checkpoint for ``model_path``, reusing a cached one if valid.

    The cache file sits next to the model: the ``.pth`` suffix (if present)
    is replaced by ``.half.jit`` when ``is_half`` is true, else ``.jit``.
    A cached checkpoint is reused only when its ``"device"`` entry equals
    ``str(device)``; otherwise ``exporter`` is called to build a new one.

    Args:
        model_path: Path to the source model file.
        is_half: Whether the half-precision variant is wanted.
        device: Target device; compared as ``str(device)`` with the cache.
        exporter: Callable invoked with keyword arguments ``model_path``,
            ``mode``, ``inputs_path``, ``save_path``, ``device`` and
            ``is_half``; its return value is returned as the checkpoint.

    Returns:
        The cached checkpoint, or whatever ``exporter`` returns.
    """
    # Remove only a literal ".pth" suffix. str.rstrip(".pth") strips any
    # trailing run of the characters '.', 'p', 't', 'h', which mangles
    # names such as "model_depth.pth" (-> "model_de").
    suffix = ".pth"
    if model_path.endswith(suffix):
        stem = model_path[: -len(suffix)]
    else:
        stem = model_path
    jit_model_path = stem + (".half.jit" if is_half else ".jit")

    if os.path.exists(jit_model_path):
        # NOTE(security): the cache is a pickle file; unpickling can execute
        # arbitrary code, so the cache location must be trusted.
        ckpt = load_pickle(jit_model_path)
        # A cache without a "device" entry, or built for another device,
        # counts as a miss and is rebuilt below.
        if ckpt.get("device") == str(device):
            return ckpt

    return exporter(
        model_path=model_path,
        mode="script",
        inputs_path=None,
        save_path=jit_model_path,
        device=device,
        is_half=is_half,
    )