import logging
import time

import torch
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations
from torchaudio.functional import resample
from torchaudio.transforms import MelSpectrogram

from modules import config
from modules.devices import devices

from .hparams import HParams

logger = logging.getLogger(__name__)


@torch.inference_mode()
def inference_chunk(
    model,
    dwav: torch.Tensor,
    sr: int,
    device: torch.device,
    dtype: torch.dtype,
    npad=441,
) -> torch.Tensor:
    """Enhance a single 1D waveform chunk, returning a CPU tensor trimmed to the input length."""
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
    del sr

    length = dwav.shape[-1]
    abs_max = dwav.abs().max().clamp(min=1e-7)

    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
    dwav = dwav.to(device=device, dtype=dtype)
    # Normalize to unit peak, pad the tail, run the model, then trim and undo the scaling.
    dwav = dwav / abs_max
    dwav = F.pad(dwav, (0, npad))
    hwav: torch.Tensor = model(dwav[None])[0].cpu()
    hwav = hwav[:length]
    hwav = hwav * abs_max

    return hwav


def compute_corr(x, y):
    # Circular cross-correlation along the last axis via the FFT correlation theorem.
    return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()


def compute_offset(chunk1, chunk2, sr=44100):
    """
    Args:
        chunk1: (T,)
        chunk2: (T,)
    Returns:
        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
    """
    hop_length = sr // 200  # ~5 ms hop
    win_length = hop_length * 4
    n_fft = 2 ** (win_length - 1).bit_length()

    mel_fn = MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=80,
        f_min=0.0,
        f_max=sr // 2,
    )

    chunk1 = chunk1.float()
    chunk2 = chunk2.float()

    spec1 = mel_fn(chunk1).log1p()
    spec2 = mel_fn(chunk2).log1p()

    # Correlate the log-mel frames of the two chunks and average over mel bins.
    corr = compute_corr(spec1, spec2)
    corr = corr.mean(dim=0)

    argmax = corr.argmax().item()

    # Peaks past the midpoint wrap around to negative lags (circular correlation).
    if argmax > len(corr) // 2:
        argmax -= len(corr)

    offset = -argmax * hop_length

    return offset


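# Sign-convention sketch for compute_offset above (hypothetical values, kept as a
# comment so nothing runs at import time). With sr=44100 the hop is 220 samples,
# so the returned offset is only resolved to the nearest hop:
#
#   a = torch.randn(44100)
#   b = torch.roll(a, 441)   # b lags a by 441 samples
#   compute_offset(a, b)     # ~ 440, i.e. a ~= b.roll(-440)

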
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
    """Cross-fade overlapping chunks back into one signal, aligning each chunk
    to its predecessor with compute_offset before blending."""
    signal_length = (len(chunks) - 1) * hop_length + chunk_length
    overlap_length = chunk_length - hop_length
    signal = torch.zeros(signal_length, device=chunks[0].device)

    # Linear cross-fade windows: ramp up over the leading overlap, ramp down over the trailing one.
    fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
    fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
    fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
    fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])

    for i, chunk in enumerate(chunks):
        start = i * hop_length
        end = start + chunk_length

        if len(chunk) < chunk_length:
            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))

        if i > 0:
            # Nudge the chunk so its head lines up with the tail of the previous chunk.
            pre_region = chunks[i - 1][-overlap_length:]
            cur_region = chunk[:overlap_length]
            offset = compute_offset(pre_region, cur_region, sr=sr)
            start -= offset
            end -= offset

        if i == 0:
            chunk = chunk * fadeout
        elif i == len(chunks) - 1:
            chunk = chunk * fadein
        else:
            chunk = chunk * fadein * fadeout

        signal[start:end] += chunk[: len(signal[start:end])]

    signal = signal[:length]

    return signal


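# Cross-fade sanity check for merge_chunks above (comment-only sketch): ignoring
# any alignment shift, the two linear ramps sum to one over the overlap region,
# so a constant signal passes through the blend unchanged:
#
#   overlap = 44100
#   ramp_up = torch.linspace(0, 1, overlap)
#   ramp_down = torch.linspace(1, 0, overlap)
#   assert torch.allclose(ramp_up + ramp_down, torch.ones(overlap))

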
def remove_weight_norm_recursively(module):
    # Strip weight-norm parametrizations from every submodule that has one;
    # submodules without a "weight" parametrization raise and are skipped.
    for _, submodule in module.named_modules():
        try:
            remove_parametrizations(submodule, "weight")
        except Exception:
            pass


def inference(
    model,
    dwav,
    sr,
    device,
    dtype,
    chunk_seconds: float = 30.0,
    overlap_seconds: float = 1.0,
):
    if config.runtime_env_vars.off_tqdm:
        trange = range
    else:
        from tqdm import trange

    remove_weight_norm_recursively(model)

    hp: HParams = model.hp

    # Resample the input to the model's rate with a Kaiser-windowed sinc filter.
    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=hp.wav_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )

    del sr
    sr = hp.wav_rate

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start_time = time.perf_counter()

    chunk_length = int(sr * chunk_seconds)
    overlap_length = int(sr * overlap_seconds)
    hop_length = chunk_length - overlap_length

    # Enhance overlapping chunks one at a time, collecting the results on CPU.
    chunks = []
    for start in trange(0, dwav.shape[-1], hop_length):
        chunk_dwav = inference_chunk(
            model, dwav[start : start + chunk_length], sr, device, dtype
        )
        chunks.append(chunk_dwav.cpu())
        devices.torch_gc()

    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    elapsed_time = time.perf_counter() - start_time
    logger.info(
        f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
    )
    devices.torch_gc()

    return hwav, sr
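

# Usage sketch (hypothetical caller code; loading the enhancer model itself is
# outside this module):
#
#   import torchaudio
#   dwav, file_sr = torchaudio.load("speech.wav")   # (channels, T)
#   dwav = dwav.mean(dim=0)                         # downmix to mono, (T,)
#   hwav, out_sr = inference(
#       model,
#       dwav,
#       file_sr,
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#       dtype=torch.float32,
#   )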