File size: 1,699 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
from collections import OrderedDict

from models.svc.base import SVCInference
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer


class TransformerInference(SVCInference):
    """SVC inference for the Transformer / Conformer acoustic model.

    The model is a ``ConditionEncoder`` followed by an acoustic mapper whose
    class is chosen by ``cfg.model.transformer.type``. The rest of the
    inference pipeline is inherited from ``SVCInference``.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        SVCInference.__init__(self, args, cfg, infer_type)

    def _build_model(self):
        """Build the condition encoder and the acoustic mapper.

        Side effect: copies ``cfg.preprocess.f0_min`` / ``f0_max`` into
        ``cfg.model.condition_encoder`` before the encoder is constructed, and
        stores the two sub-modules on ``self.condition_encoder`` and
        ``self.acoustic_mapper``.

        Returns:
            torch.nn.ModuleList: ``[condition_encoder, acoustic_mapper]``.

        Raises:
            NotImplementedError: If ``cfg.model.transformer.type`` is neither
                ``"transformer"`` nor ``"conformer"``.
        """
        # The encoder takes its F0 bounds from the preprocessing config so the
        # two stay in agreement.
        encoder_cfg = self.cfg.model.condition_encoder
        encoder_cfg.f0_min = self.cfg.preprocess.f0_min
        encoder_cfg.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(encoder_cfg)

        mapper_cfg = self.cfg.model.transformer
        if mapper_cfg.type == "transformer":
            self.acoustic_mapper = Transformer(mapper_cfg)
        elif mapper_cfg.type == "conformer":
            self.acoustic_mapper = Conformer(mapper_cfg)
        else:
            # Name the offending value so a config typo is easy to spot.
            raise NotImplementedError(
                f"Unsupported cfg.model.transformer.type: {mapper_cfg.type!r} "
                "(expected 'transformer' or 'conformer')"
            )

        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _inference_each_batch(self, batch_data):
        """Run the condition encoder and acoustic mapper on one batch.

        Args:
            batch_data (dict): Batch inputs keyed by name. It must contain
                ``"mask"`` and whatever keys ``ConditionEncoder`` reads.
                Tensor values are replaced in place by copies on the
                accelerator device.

        Returns:
            The acoustic mapper's output for this batch.
        """
        device = self.accelerator.device
        for key, value in batch_data.items():
            # Only tensors can be moved to a device. Any other entry is left
            # as-is instead of failing with AttributeError.
            if isinstance(value, torch.Tensor):
                batch_data[key] = value.to(device)

        condition = self.condition_encoder(batch_data)
        y_pred = self.acoustic_mapper(condition, batch_data["mask"])

        return y_pred