File size: 1,286 Bytes
f80c5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os

import torch
from huggingface_hub import hf_hub_download

from .rmvpe import RMVPE
from ..auto_loader import auto_loaded_model


def _rmvpe_from_checkpoint(path: str | os.PathLike, device: torch.device) -> RMVPE:
    """Build an RMVPE network, load the state dict stored at ``path`` and move it to ``device``."""
    # Fixed constructor arguments; a checkpoint saved for a different architecture
    # will be rejected by load_state_dict (strict matching by default).
    model = RMVPE(4, 1, (2, 2))
    # Load tensors onto the CPU first: they are copied into the module's own
    # parameters anyway, so mapping them straight to an accelerator would only
    # allocate a throw-away copy there.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Consider passing weights_only=True if the minimum supported torch version allows it.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device)


def load_rmvpe(
    rmvpe: str | os.PathLike | RMVPE | None = None,
    device: torch.device = torch.device("cpu"),
) -> RMVPE:
    """
    Load the RMVPE model from a file, or download the default checkpoint if necessary.

    If an already-loaded model is provided, it is moved to ``device`` and returned.

    Args:
        rmvpe (str | os.PathLike | RMVPE | None): Path to an RMVPE checkpoint
            (a state dict), or a pre-loaded RMVPE model. If None, ``rmvpe.pt`` is
            downloaded from the ``lj1995/VoiceConversionWebUI`` Hugging Face
            repository on first use. The resulting model is then cached in
            ``auto_loaded_model["rmvpe"]`` and reused by later calls.
        device (torch.device): The device to place the model on.

    Returns:
        RMVPE: The loaded model, on ``device``.

    Raises:
        FileNotFoundError: If ``rmvpe`` is a path to a file that does not exist.
        TypeError: If ``rmvpe`` is neither a path, an RMVPE instance, nor None.
    """
    if isinstance(rmvpe, RMVPE):
        return rmvpe.to(device)
    if isinstance(rmvpe, (str, os.PathLike)):
        return _rmvpe_from_checkpoint(rmvpe, device)
    if rmvpe is not None:
        # Previously any unexpected value silently fell back to the default
        # download, hiding caller mistakes.
        raise TypeError(
            "rmvpe must be a path, an RMVPE instance or None, "
            f"got {type(rmvpe).__name__}"
        )
    if "rmvpe" not in auto_loaded_model:
        checkpoint = hf_hub_download("lj1995/VoiceConversionWebUI", "rmvpe.pt")
        auto_loaded_model["rmvpe"] = _rmvpe_from_checkpoint(checkpoint, device)
    # The cached instance may have been created for a different device on an
    # earlier call, so honour the device requested now. Note that this moves
    # the shared cached model, the same way the RMVPE-instance branch above
    # moves the caller's model.
    return auto_loaded_model["rmvpe"].to(device)