File size: 2,473 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from modules.diffusion import BiDilConv
from modules.encoder.position_encoder import PositionEncoder


class DiffusionWrapper(nn.Module):
    """Wraps a diffusion step encoder and a denoising network.

    The step encoder (``PositionEncoder``) embeds the diffusion time step, and
    the denoising network (currently only ``BiDilConv``) is called with the
    channel-first mel spectrogram, the step embedding and the channel-first
    conditioner.

    Args:
        cfg: Experiment config. Reads ``cfg.model.diffusion`` (``model_type``,
            ``step_encoder.*`` and ``bidilconv.*``) and ``cfg.preprocess.n_mel``.

    Raises:
        ValueError: If ``cfg.model.diffusion.model_type`` is not
            ``"bidilconv"`` (compared case-insensitively).
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.diff_cfg = cfg.model.diffusion

        # Step encoder; its output width is `bidilconv.base_channel` so the
        # step embedding can be passed directly to the BiDilConv network.
        self.diff_encoder = PositionEncoder(
            d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
            d_out=self.diff_cfg.bidilconv.base_channel,
            d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
            activation_function=self.diff_cfg.step_encoder.activation,
            n_layer=self.diff_cfg.step_encoder.num_layer,
            max_period=self.diff_cfg.step_encoder.max_period,
        )

        # FIXME: Only support BiDilConv now for debug
        if self.diff_cfg.model_type.lower() == "bidilconv":
            # The number of mel bins is the network's input channel count;
            # the remaining constructor args come from the `bidilconv` config.
            self.neural_network = BiDilConv(
                input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
            )
        else:
            raise ValueError(
                f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
            )

    def forward(self, x, t, c):
        """
        Args:
            x: [N, T, mel_band] of mel spectrogram
            t: Diffusion time step with shape of [N]
            c: [N, T, conditioner_size] of conditioner

        Returns:
            [N, T, mel_band] of mel spectrogram

        Raises:
            AssertionError: If ``x`` is not 3D, ``t`` is not 1D, or the N/T
                dimensions of ``x``, ``t`` and ``c`` disagree, or the network
                output shape differs from ``x``. (These are ``assert``
                statements and are skipped under ``python -O``.)
        """

        # Check ranks before indexing into sizes: `t.size(0)` on a 0-dim
        # tensor raises IndexError, which would otherwise pre-empt the
        # clearer rank message below.
        assert x.dim() == 3, "x must be 3D tensor, got {}".format(x.dim())
        assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
        assert (
            x.size()[:-1] == c.size()[:-1]
        ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
        assert x.size(0) == t.size(
            0
        ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())

        N, T, mel_band = x.size()

        # The network works channel-first, so move the feature axis to dim 1.
        x = x.transpose(1, 2).contiguous()  # [N, mel_band, T]
        c = c.transpose(1, 2).contiguous()  # [N, conditioner_size, T]
        t = self.diff_encoder(t).contiguous()  # [N, base_channel]

        h = self.neural_network(x, t, c)
        h = h.transpose(1, 2).contiguous()  # [N, T, mel_band]

        assert h.size() == (
            N,
            T,
            mel_band,
        ), "h mismatch with input x, got \n h: {} \n x: {}".format(
            h.size(), (N, T, mel_band)
        )
        return h