# File size: 5,207 Bytes
# ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from collections import OrderedDict
from typing import Tuple
import warnings

import torch

from espnet2.enh.separator.abs_separator import AbsSeparator


class AsteroidModel_Converter(AbsSeparator):
    """Adapter that exposes an Asteroid model through the AbsSeparator interface.

    The wrapped model is fed raw waveforms and its output is split into one
    estimated waveform per speaker.
    """

    def __init__(
        self,
        encoder_output_dim: int,
        model_name: str,
        num_spk: int,
        pretrained_path: str = "",
        loss_type: str = "si_snr",
        **model_related_kwargs,
    ):
        """The class to convert the models from asteroid to AbsSeparator.

        Args:
            encoder_output_dim: input feature dimension, default=1 after the
                        NullEncoder. Must be 1 (raw-waveform domain).
            num_spk: number of speakers
            loss_type: loss type of enhancement. Only "si_snr" is supported.
            model_name: Asteroid model names, e.g. ConvTasNet, DPTNet. Refers to
                        https://github.com/asteroid-team/asteroid/
                        blob/master/asteroid/models/__init__.py
            pretrained_path: the name of pretrained model from Asteroid in HF hub.
                             Refers to: https://github.com/asteroid-team/asteroid/
                             blob/master/docs/source/readmes/pretrained_models.md and
                             https://huggingface.co/models?filter=asteroid
            model_related_kwargs: more args towards each specific asteroid model.
                        The literal string "None" is converted to ``None``.
                        Ignored (with a warning) when ``pretrained_path`` is set.

        Raises:
            ValueError: if ``loss_type`` is not "si_snr".
        """
        super().__init__()

        # The input should be in the raw-wave domain.
        assert encoder_output_dim == 1, (
            "encoder_output_dim must be 1 (raw waveform input), got %s"
            % encoder_output_dim
        )

        # Fail fast: reject an unsupported loss before the model is built or a
        # pretrained checkpoint is fetched.
        if loss_type != "si_snr":
            raise ValueError("Unsupported loss type: %s" % loss_type)

        # Please make sure the installation of Asteroid.
        # https://github.com/asteroid-team/asteroid
        from asteroid import models

        # Values may arrive as the string "None"; map those to a real None.
        # The isinstance guard keeps the comparison away from non-string
        # values (e.g. array-like objects) whose `==` is not a plain bool.
        model_related_kwargs = {
            k: None if isinstance(v, str) and v == "None" else v
            for k, v in model_related_kwargs.items()
        }

        model_cls = getattr(models, model_name)
        if pretrained_path:
            model = model_cls.from_pretrained(pretrained_path)
            if model_related_kwargs:
                # The extra kwargs are not forwarded to the pretrained model.
                warnings.warn(
                    "Pretrained model should get no args with %s"
                    % model_related_kwargs
                )
        else:
            model = model_cls(**model_related_kwargs)

        self.model = model
        self._num_spk = num_spk
        self.loss_type = loss_type

    def forward(self, input: torch.Tensor, ilens: torch.Tensor = None):
        """Whole forward of asteroid models.

        Args:
            input (torch.Tensor): Raw Waveforms [B, T] (or [T] for a single
                unbatched waveform)
            ilens (torch.Tensor): input lengths [B]; returned unchanged

        Returns:
            estimated Waveforms (List[torch.Tensor]): [(B, T), ...], one entry
                per speaker ([(T,), ...] for unbatched input)
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, T),
                'mask_spk2': torch.Tensor(Batch, T),
                ...
                'mask_spkn': torch.Tensor(Batch, T),
            ]
            Note that the "mask" entries hold the same tensors as the
            estimated waveforms.
        """
        if hasattr(self.model, "forward_wav"):
            est_source = self.model.forward_wav(input)  # B,nspk,T or nspk,T
        else:
            est_source = self.model(input)  # B,nspk,T or nspk,T

        # The speaker axis is the first one for unbatched (1-D) input and the
        # second one otherwise.
        spk_dim = 0 if input.dim() == 1 else 1
        assert est_source.size(spk_dim) == self.num_spk, (
            "model produced %d sources but num_spk is %d"
            % (est_source.size(spk_dim), self.num_spk)
        )

        # Split along the speaker axis. For batched input this equals iterating
        # over est_source.transpose(0, 1); for 1-D input it avoids splitting
        # along the time axis by mistake.
        est_source = list(est_source.unbind(dim=spk_dim))  # List(B,T) or List(T)
        masks = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(self.num_spk)], est_source)
        )
        return est_source, ilens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[list, torch.Tensor, OrderedDict]:
        """Output with waveforms.

        Thin alias of :meth:`forward`; returns the same
        ``(est_source, ilens, masks)`` triple.
        """
        return self.forward(input, ilens)

    @property
    def num_spk(self):
        """Number of speakers given at construction time."""
        return self._num_spk


if __name__ == "__main__":
    # Manual smoke test: run a random 3 x 16000 mixture through the converter,
    # first with a pretrained model, then with one built from explicit kwargs.
    mixture = torch.randn(3, 16000)
    print("mixture shape", mixture.shape)

    # Case 1: model loaded via `pretrained_path`.
    pretrained_setup = dict(
        model_name="ConvTasNet",
        encoder_output_dim=1,
        num_spk=2,
        loss_type="si_snr",
        pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean",
    )
    net = AsteroidModel_Converter(**pretrained_setup)
    print("model", net)
    # Only the first returned item (the list of estimated sources) is kept.
    output = net(mixture)[0]
    # 111 is passed in the `ilens` position of forward_rawwav.
    output = net.forward_rawwav(mixture, 111)[0]
    print("output spk1 shape", output[0].shape)

    # Case 2: model built from scratch; everything that is not a named
    # parameter of the converter is forwarded as a model-specific kwarg.
    converter_args = dict(
        encoder_output_dim=1,
        num_spk=2,
        model_name="ConvTasNet",
        loss_type="si_snr",
    )
    convtasnet_args = dict(
        n_src=2,
        out_chan=None,
        n_blocks=2,
        n_repeats=2,
        bn_chan=128,
        hid_chan=512,
        skip_chan=128,
        conv_kernel_size=3,
        norm_type="gLN",
        mask_act="sigmoid",
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=512,
        stride=8,
        encoder_activation=None,
        sample_rate=8000,
    )
    net = AsteroidModel_Converter(**converter_args, **convtasnet_args)
    print("\n\nmodel", net)
    output = net(mixture)[0]
    first_source_shape = output[0].shape
    print("output spk1 shape", first_source_shape)
    print("Finished", first_source_shape)