File size: 5,118 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
d2b7e94
627d3d7
 
32b2aaa
 
 
 
 
 
627d3d7
 
 
 
 
 
 
 
32b2aaa
 
 
 
 
 
 
627d3d7
32b2aaa
 
627d3d7
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627d3d7
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b473486
627d3d7
 
 
 
 
 
 
b473486
 
 
ae79826
 
b473486
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627d3d7
 
b473486
627d3d7
 
32b2aaa
 
 
 
 
 
 
b473486
 
 
627d3d7
32b2aaa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import logging
import time

import torch
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations
from torchaudio.functional import resample
from torchaudio.transforms import MelSpectrogram
from tqdm import trange

from modules import config
from modules.devices import devices

from .hparams import HParams

logger = logging.getLogger(__name__)


@torch.inference_mode()
def inference_chunk(
    model,
    dwav: torch.Tensor,
    sr: int,
    device: torch.device,
    dtype: torch.dtype,
    npad=441,
) -> torch.Tensor:
    """
    Run the model on a single 1D waveform chunk.

    The chunk is peak-normalized, right-padded with ``npad`` zeros, passed
    through ``model`` as a batch of one, then trimmed back to its original
    length and rescaled by the original peak.

    Args:
        model: callable exposing ``hp.wav_rate``; called with a (1, T + npad) tensor
        dwav: (T,) input waveform
        sr: sample rate of ``dwav``; must equal ``model.hp.wav_rate``
        device: device the model input is moved to
        dtype: dtype the model input is cast to
        npad: number of zero samples appended before the model call
    Returns:
        hwav: model output on the CPU, trimmed to at most T samples
    """
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
    del sr

    # Remember the original size and peak so both can be restored afterwards.
    num_samples = dwav.shape[-1]
    peak = dwav.abs().max().clamp(min=1e-7)  # clamp avoids dividing by zero on silence

    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"

    normalized = dwav.to(device=device, dtype=dtype) / peak
    padded = F.pad(normalized, (0, npad))

    batch_out = model(padded.unsqueeze(0))
    hwav: torch.Tensor = batch_out[0].cpu()  # (T,)

    # Drop the padded tail, then undo the peak normalization.
    return hwav[:num_samples] * peak


def compute_corr(x, y):
    """
    Magnitude of the circular cross-correlation of ``x`` and ``y``.

    Computed with FFTs along the last dimension: the spectrum of ``x`` is
    multiplied by the conjugate spectrum of ``y`` and transformed back.
    """
    spectrum_x = torch.fft.fft(x)
    spectrum_y_conj = torch.fft.fft(y).conj()
    cross = torch.fft.ifft(spectrum_x * spectrum_y_conj)
    return cross.abs()


def compute_offset(chunk1, chunk2, sr=44100):
    """
    Estimate the time shift between two waveforms.

    Both inputs are turned into log-compressed mel spectrograms, the two
    spectrograms are circularly cross-correlated along time for every mel
    band, and the lag with the highest band-averaged correlation is converted
    back to samples.

    Args:
        chunk1: (T,)
        chunk2: (T,), expected to have the same length as chunk1
        sr: sample rate in Hz; sets the analysis hop (sr // 200) and window
    Returns:
        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset).
            It is always a multiple of the analysis hop.
    """
    hop_length = sr // 200  # 5 ms resolution
    win_length = hop_length * 4
    n_fft = 2 ** (win_length - 1).bit_length()  # smallest power of two >= win_length

    # Build the transform on the same device as the input: the transform's
    # internal buffers default to the CPU, which breaks on non-CPU chunks.
    mel_fn = MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=80,
        f_min=0.0,
        f_max=sr // 2,
    ).to(chunk1.device)

    chunk1 = chunk1.float()
    chunk2 = chunk2.float()

    spec1 = mel_fn(chunk1).log1p()
    spec2 = mel_fn(chunk2).log1p()

    corr = compute_corr(spec1, spec2)  # (F, T)
    corr = corr.mean(dim=0)  # (T,)

    argmax = corr.argmax().item()

    # The correlation is circular: lags in the upper half represent negative shifts.
    if argmax > len(corr) // 2:
        argmax -= len(corr)

    # Convert the lag from spectrogram frames to samples.
    offset = -argmax * hop_length

    return offset


def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
    """
    Overlap-add consecutive chunks into one signal using linear cross-fades.

    Chunk ``i`` is nominally placed at ``i * hop_length``. Every chunk after
    the first is shifted by the offset that ``compute_offset`` estimates
    between the tail of the previous chunk and the head of the current one.

    Args:
        chunks: non-empty list of 1D tensors; chunks shorter than
            ``chunk_length`` are zero-padded on the right
        chunk_length: nominal length of every chunk in samples
        hop_length: distance in samples between the starts of adjacent chunks
        sr: sample rate in Hz, forwarded to ``compute_offset``
        length: if given, the merged signal is truncated to this many samples
    Returns:
        signal: 1D tensor on the device of ``chunks[0]``
    """
    device = chunks[0].device
    last_index = len(chunks) - 1
    signal_length = last_index * hop_length + chunk_length
    overlap_length = chunk_length - hop_length
    signal = torch.zeros(signal_length, device=device)

    # Fade ramps cover only the overlapping part; the rest of the chunk is kept as is.
    fadein = torch.linspace(0, 1, overlap_length, device=device)
    fadein = torch.cat([fadein, torch.ones(hop_length, device=device)])
    fadeout = torch.linspace(1, 0, overlap_length, device=device)
    fadeout = torch.cat([torch.ones(hop_length, device=device), fadeout])

    prev_chunk = None  # previous chunk, padded to chunk_length but not faded

    for i, chunk in enumerate(chunks):
        start = i * hop_length
        end = start + chunk_length

        if len(chunk) < chunk_length:
            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))

        # Align against the *padded* previous chunk: a short previous chunk's
        # own tail does not sit at the nominal overlap position, which would
        # yield a bogus offset. Without any overlap there is nothing to align.
        if prev_chunk is not None and overlap_length > 0:
            pre_region = prev_chunk[-overlap_length:]
            cur_region = chunk[:overlap_length]
            offset = compute_offset(pre_region, cur_region, sr=sr)
            start -= offset
            end -= offset

        prev_chunk = chunk

        # The first chunk has no predecessor to fade in from and the last one
        # has no successor to fade out into.
        if i == 0:
            chunk = chunk * fadeout
        elif i == last_index:
            chunk = chunk * fadein
        else:
            chunk = chunk * fadein * fadeout

        # A shifted chunk may run past the end of the buffer; add only what fits.
        signal[start:end] += chunk[: len(signal[start:end])]

    signal = signal[:length]

    return signal


def remove_weight_norm_recursively(module):
    """
    Remove the parametrization on ``weight`` from ``module`` and all submodules.

    Modules are modified in place. Modules without a parametrization on
    ``weight`` are skipped, so calling this repeatedly is harmless. The removal
    is best-effort: an unexpected failure on one module is logged and the
    remaining modules are still processed.
    """
    from torch.nn.utils.parametrize import is_parametrized

    for _, submodule in module.named_modules():
        # Check explicitly instead of relying on the error raised for
        # non-parametrized modules, so real failures are not hidden.
        if not is_parametrized(submodule, "weight"):
            continue
        try:
            remove_parametrizations(submodule, "weight")
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to remove weight parametrization from %s",
                type(submodule).__name__,
                exc_info=True,
            )


def inference(
    model,
    dwav,
    sr,
    device,
    dtype,
    chunk_seconds: float = 30.0,
    overlap_seconds: float = 1.0,
):
    """
    Run the model over a whole waveform in overlapping chunks.

    The input is resampled to ``model.hp.wav_rate``, split into chunks of
    ``chunk_seconds`` that overlap by ``overlap_seconds``, each chunk is
    processed by ``inference_chunk``, and the results are stitched together
    with ``merge_chunks``. Weight parametrizations are removed from ``model``
    in place before processing.

    Args:
        model: model exposing ``hp`` (an ``HParams``) with ``wav_rate``
        dwav: input waveform; chunked along its last dimension
        sr: sample rate of ``dwav`` in Hz
        device: device passed to ``inference_chunk``
        dtype: dtype passed to ``inference_chunk``
        chunk_seconds: chunk duration in seconds
        overlap_seconds: overlap between adjacent chunks in seconds
    Returns:
        (hwav, sr): merged output waveform and its sample rate (``hp.wav_rate``)
    """
    # Use a plain range when progress bars are disabled via the runtime config.
    if config.runtime_env_vars.off_tqdm:
        progress_range = range
    else:
        from tqdm import trange as progress_range

    remove_weight_norm_recursively(model)

    hp: HParams = model.hp
    target_rate = hp.wav_rate

    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=target_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )

    # Everything is in hp.wav_rate from here on.
    sr = target_rate

    # Synchronize so the timing below is not skewed by pending CUDA work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start_time = time.perf_counter()

    chunk_length = int(sr * chunk_seconds)
    hop_length = chunk_length - int(sr * overlap_seconds)
    total_samples = dwav.shape[-1]

    chunks = []
    for begin in progress_range(0, total_samples, hop_length):
        piece = dwav[begin : begin + chunk_length]
        enhanced = inference_chunk(model, piece, sr, device, dtype)
        chunks.append(enhanced.cpu())
        devices.torch_gc()

    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=total_samples)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    elapsed_time = time.perf_counter() - start_time
    logger.info(
        f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
    )
    devices.torch_gc()

    return hwav, sr