ZeroRVC / zerorvc /f0 /load.py
github-actions[bot]
Sync to HuggingFace Spaces
f80c5ec
raw
history blame
No virus
1.29 kB
import torch
from huggingface_hub import hf_hub_download
from .rmvpe import RMVPE
from ..auto_loader import auto_loaded_model
def _rmvpe_from_checkpoint(checkpoint_path: str, device: torch.device) -> RMVPE:
    """Build an RMVPE model, load weights from *checkpoint_path*, move it to *device*."""
    model = RMVPE(4, 1, (2, 2))
    # Deserialize on CPU: the freshly built model lives on CPU, so
    # load_state_dict copies into CPU parameters anyway; the single move to
    # the target device happens afterwards.
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # trusted sources; consider weights_only=True if the minimum supported
    # torch version allows it.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device)


def load_rmvpe(
    rmvpe: str | RMVPE | None = None, device: torch.device = torch.device("cpu")
) -> RMVPE:
    """
    Load the RMVPE model from a file or download it if necessary.

    If a loaded model is provided, it is moved to ``device`` and returned.
    If a path is provided, a new model is built from that checkpoint.
    If None, the default checkpoint is downloaded from the Hugging Face Hub
    on first use and the resulting model is cached in ``auto_loaded_model``
    under the key ``"rmvpe"``; later calls reuse that cached instance.

    Args:
        rmvpe (str | RMVPE | None): The path to the RMVPE model file or the
            pre-loaded RMVPE model. If None, the default model will be
            downloaded.
        device (torch.device): The device to load the model on.

    Returns:
        RMVPE: The loaded RMVPE model, placed on ``device``.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    if isinstance(rmvpe, RMVPE):
        return rmvpe.to(device)

    if isinstance(rmvpe, str):
        return _rmvpe_from_checkpoint(rmvpe, device)

    if "rmvpe" not in auto_loaded_model:
        checkpoint_path = hf_hub_download("lj1995/VoiceConversionWebUI", "rmvpe.pt")
        auto_loaded_model["rmvpe"] = _rmvpe_from_checkpoint(checkpoint_path, device)

    # Always honour the requested device, even on a cache hit: the cached
    # model may have been created by an earlier call targeting another device.
    # The cached instance itself is moved, so it is shared, not copied.
    return auto_loaded_model["rmvpe"].to(device)