|
from typing import Dict, List, Tuple |
|
|
|
import numpy as np |
|
from asteroid import separate |
|
from asteroid.models import BaseModel |
|
import os |
|
|
|
|
|
class PreTrainedPipeline:
    """Audio source-separation pipeline wrapping a pretrained Asteroid model."""

    def __init__(self, path: str = ""):
        """Load the pretrained model and record its sampling rate.

        Args:
            path: Directory expected to contain a ``pytorch_model.bin``
                checkpoint; it is passed to ``BaseModel.from_pretrained``.
        """
        self.model = BaseModel.from_pretrained(os.path.join(path, "pytorch_model.bin"))
        # Sampling rate (Hz) the incoming waveform is assumed to use.
        self.sampling_rate = self.model.sample_rate

    @staticmethod
    def _to_model_input(inputs: np.ndarray) -> np.ndarray:
        """Validate a mono waveform and reshape it to ``(1, 1, T)``.

        The leading two axes are a batch of one item with one channel.
        Singleton axes are dropped first, so ``(T,)``, ``(1, T)`` and
        ``(T, 1)`` are all accepted.

        Raises:
            ValueError: if the waveform is empty, or if it has more than one
                non-singleton axis (e.g. multi-channel audio). A plain reshape
                would silently splice such input into one long channel.
        """
        waveform = np.squeeze(np.asarray(inputs))
        if waveform.size == 0:
            raise ValueError("inputs must be a non-empty waveform")
        if waveform.ndim > 1:
            raise ValueError(
                f"inputs must be a mono waveform of shape (T,), got shape {np.shape(inputs)}"
            )
        return waveform.reshape((1, 1, -1))

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        """Separate a mono waveform into its estimated output channels.

        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of audio received, by default sampled at
                `self.sampling_rate`. The shape of this array is `T`, where
                `T` is the time axis. Singleton-padded shapes such as
                `(1, T)` or `(T, 1)` are also accepted.

        Returns:
            A :obj:`tuple` containing:

            - :obj:`np.ndarray`: the separated signals, of shape `C'` x `T'`.
            - :obj:`int`: the sampling rate in Hz.
            - :obj:`List[str]`: one annotation per output channel (length
              `C'`). This can be the name of the instruments for audio source
              separation, or some annotation for speech enhancement.

        Raises:
            ValueError: if ``inputs`` is empty or is not a mono waveform.
        """
        # Validate before touching the model so bad input fails fast and clearly.
        model_input = self._to_model_input(inputs)
        separated = separate.numpy_separate(self.model, model_input)
        # The batch holds a single item; keep only its per-channel estimates.
        out = separated[0]
        # Placeholder annotations: one generic label per output channel.
        labels = [f"label_{i}" for i in range(out.shape[0])]
        return out, int(self.model.sample_rate), labels
|
|