Spaces:
Runtime error
Runtime error
import torch | |
from huggingface_hub import hf_hub_download | |
from .rmvpe import RMVPE | |
from ..auto_loader import auto_loaded_model | |
def load_rmvpe( | |
rmvpe: str | RMVPE | None = None, device: torch.device = torch.device("cpu") | |
) -> RMVPE: | |
""" | |
Load the RMVPE model from a file or download it if necessary. | |
If a loaded model is provided, it will be returned as is. | |
Args: | |
rmvpe (str | RMVPE | None): The path to the RMVPE model file or the pre-loaded RMVPE model. If None, the default model will be downloaded. | |
device (torch.device): The device to load the model on. | |
Returns: | |
RMVPE: The loaded RMVPE model. | |
Raises: | |
If the model file does not exist. | |
""" | |
if isinstance(rmvpe, RMVPE): | |
return rmvpe.to(device) | |
if isinstance(rmvpe, str): | |
model = RMVPE(4, 1, (2, 2)) | |
model.load_state_dict(torch.load(rmvpe, map_location=device)) | |
model.to(device) | |
return model | |
if "rmvpe" not in auto_loaded_model: | |
rmvpe = hf_hub_download("lj1995/VoiceConversionWebUI", "rmvpe.pt") | |
model = RMVPE(4, 1, (2, 2)) | |
model.load_state_dict(torch.load(rmvpe, map_location="cpu")) | |
model.to(device) | |
auto_loaded_model["rmvpe"] = model | |
return auto_loaded_model["rmvpe"] | |