from typing import Dict, List, Tuple | |
import numpy as np | |
import os | |
import torch | |
class PreTrainedPipeline():
    """Source-separation pipeline built on the open-unmix ``umxhq`` model.

    The separator is fetched with ``torch.hub`` from
    ``sigsep/open-unmix-pytorch`` at construction time. Calling the pipeline
    runs it on a single waveform and returns the entry at index ``[0][0]`` of
    the model output (the batch item we built, first target -- this code
    treats it as ``vocals``).
    """

    # Channel count fed to the separator. ``umxhq`` is a stereo model, so a
    # mono waveform is duplicated to this many channels before inference.
    # NOTE(review): keep in sync if a different open-unmix variant is loaded.
    nb_channels = 2

    def __init__(self, path=''):
        """Load the pretrained ``umxhq`` separator.

        Args:
            path: Accepted for pipeline-API compatibility only; it is not
                used -- the model always comes from ``torch.hub``.
        """
        self.model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
        # ``sample_rate`` is exposed by the model as a tensor; keep a plain int.
        self.sampling_rate = int(self.model.sample_rate.item())

    def _to_batch(self, inputs) -> torch.Tensor:
        """Return *inputs* as a float32 tensor shaped (1, channels, time).

        Accepts a numpy array or torch tensor of shape ``(T,)`` or ``(C, T)``.

        Raises:
            ValueError: if *inputs* has any other number of dimensions.
        """
        # as_tensor accepts both numpy arrays and tensors; casting to float32
        # avoids a dtype mismatch with the model weights for float64 input.
        audio = torch.as_tensor(inputs, dtype=torch.float32)
        if audio.dim() == 1:
            audio = audio.reshape(1, 1, -1)    # (T,)   -> (1, 1, T)
        elif audio.dim() == 2:
            audio = audio.unsqueeze(0)         # (C, T) -> (1, C, T)
        else:
            raise ValueError(
                "expected a waveform of shape (T,) or (C, T), "
                f"got {tuple(audio.shape)}"
            )
        if audio.shape[1] == 1 and self.nb_channels > 1:
            # repeat (not expand) so the tensor stays contiguous for the STFT.
            audio = audio.repeat(1, self.nb_channels, 1)
        return audio

    def __call__(self, inputs) -> Tuple[np.ndarray, int, List[str]]:
        """Separate *inputs* and return the first target's waveform.

        Args:
            inputs: Waveform of shape ``(T,)`` or ``(C, T)`` as a numpy array
                or torch tensor. It is not resampled, so it is assumed to be
                at ``self.sampling_rate``.

        Returns:
            ``(audio, sampling_rate, labels)`` where ``audio`` is the numpy
            array ``estimates[0][0]`` and ``labels`` has one
            ``"label_<i>"`` entry per row of ``audio``.

        Raises:
            ValueError: if *inputs* is not 1-D or 2-D.
        """
        batch = self._to_batch(inputs)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            estimates = self.model(batch)
        # .cpu() makes .numpy() safe even if the model ran on another device.
        vocals = estimates[0][0].detach().cpu().numpy()
        n = vocals.shape[0]
        return vocals, self.sampling_rate, [f"label_{i}" for i in range(n)]