from collections import OrderedDict
from typing import List
from typing import Tuple
from typing import Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.rnn.encoders import RNN
from espnet2.enh.separator.abs_separator import AbsSeparator


class RNNSeparator(AbsSeparator):
    def __init__(
        self,
        input_dim: int,
        rnn_type: str = "blstm",
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        layer: int = 3,
        unit: int = 512,
        dropout: float = 0.0,
    ):
        """RNN Separator.

        Args:
            input_dim: input feature dimension
            rnn_type: string, select from 'blstm', 'lstm' etc.
            num_spk: number of speakers
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'
            layer: int, number of stacked RNN layers. Default is 3.
            unit: int, dimension of the hidden state.
            dropout: float, dropout ratio. Default is 0.
        """
        super().__init__()

        self._num_spk = num_spk

        self.rnn = RNN(
            idim=input_dim,
            elayers=layer,
            cdim=unit,
            hdim=unit,
            dropout=dropout,
            typ=rnn_type,
        )

        # one linear mask-estimation head per speaker
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(unit, input_dim) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            masked (List[Union[torch.Tensor, ComplexTensor]]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,)
            others (OrderedDict): predicted data, e.g. masks:
                OrderedDict[
                    'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                    'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                    ...
                    'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                ]
        """
        # if complex spectrum, feed the magnitude to the RNN
        if isinstance(input, ComplexTensor):
            feature = abs(input)
        else:
            feature = input

        x, ilens, _ = self.rnn(feature, ilens)

        # estimate one mask per speaker from the RNN output
        masks = []
        for linear in self.linear:
            y = linear(x)
            y = self.nonlinear(y)
            masks.append(y)

        # apply each mask to the (possibly complex) input feature
        masked = [input * m for m in masks]

        others = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
        )

        return masked, ilens, others

    @property
    def num_spk(self):
        return self._num_spk
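

# A minimal usage sketch of the separator above, assuming espnet and
# torch_complex are installed as implied by the imports at the top of this
# file. The shapes are illustrative only: batch of 2, 100 frames, 129
# frequency bins, with lengths in descending order as required by the
# packed-sequence RNN encoder.
if __name__ == "__main__":
    separator = RNNSeparator(input_dim=129, rnn_type="blstm", num_spk=2)

    # fake complex spectrum [B, T, N] and per-utterance frame lengths [B]
    spectrum = ComplexTensor(torch.randn(2, 100, 129), torch.randn(2, 100, 129))
    lengths = torch.tensor([100, 80])

    masked, lengths, others = separator(spectrum, lengths)

    # `masked` holds one complex tensor per speaker;
    # `others` holds the real-valued masks keyed as 'mask_spk1', 'mask_spk2', ...
    print([m.shape for m in masked], list(others.keys()))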