|
from pathlib import Path |
|
from typing import Callable, Union |
|
|
|
from torch import Tensor |
|
|
|
|
|
def walk_paths(root, suffix):
    """Lazily yield every non-directory entry under *root* ending in *suffix*.

    Args:
        root: Directory to start from; anything ``Path`` accepts.
        suffix: Extension to match, including the leading dot (e.g.
            ``".wav"``). It is compared for exact equality against
            ``Path.suffix``, so the match is case-sensitive.

    Yields:
        ``Path`` objects for matching entries, descending depth-first into
        each sub-directory as soon as it is encountered.
    """
    for entry in Path(root).iterdir():
        if not entry.is_dir():
            # Leaf entry: emit it only when the extension matches exactly.
            if entry.suffix == suffix:
                yield entry
            continue
        # Directories are never yielded themselves (even if their name ends
        # with the suffix); they are only descended into.
        yield from walk_paths(entry, suffix)
|
|
|
|
|
def rglob_audio_files(path: Path):
    """Return a list of all ``.wav`` and ``.flac`` files found beneath *path*.

    The tree is walked once per extension, so every ``.wav`` match precedes
    every ``.flac`` match in the returned list.
    """
    return [
        audio_file
        for extension in (".wav", ".flac")
        for audio_file in walk_paths(path, extension)
    ]
|
|
|
|
|
def mix_fg_bg(
    fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7
):
    """Blend a foreground and a background signal into a peak-normalised mix.

    Both inputs are first scaled by their peak absolute value and then by the
    square root of their energy (sum of squares), each taken along the last
    dimension. They are then combined as ``alpha * fg + (1 - alpha) * bg`` and
    the result is peak-normalised again.

    Args:
        fg: (b, t)
        bg: (b, t); must have the same shape as ``fg``.
        alpha: Foreground weight in ``[0, 1]``, or a zero-argument callable
            returning one (invoked once per call).
        eps: Small constant added to every denominator to avoid division by
            zero.

    Returns:
        The mixed tensor, with the same shape as the inputs.

    Raises:
        AssertionError: If the shapes differ or ``alpha`` is outside
            ``[0, 1]``.
    """
    assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"

    def _peak(signal):
        # Largest magnitude along the last axis, kept as a size-1 dimension
        # so it broadcasts against ``signal``.
        peak, _ = signal.abs().max(dim=-1, keepdim=True)
        return peak

    def _energy(signal):
        # Sum of squares along the last axis, kept broadcastable.
        return (signal**2).sum(dim=-1, keepdim=True)

    # Stage 1: bring both signals to (roughly) unit peak amplitude.
    fg = fg / (_peak(fg) + eps)
    bg = bg / (_peak(bg) + eps)

    # Stage 2: bring both signals to (roughly) unit energy so that ``alpha``
    # weights signals of comparable scale.
    fg_scale = (_energy(fg) + eps).sqrt()
    bg_scale = (_energy(bg) + eps).sqrt()
    fg = fg / fg_scale
    bg = bg / bg_scale

    # A callable ``alpha`` is resolved only now, after normalisation.
    if callable(alpha):
        alpha = alpha()

    assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"

    fg_weight = alpha
    bg_weight = 1 - alpha
    mixture = fg_weight * fg + bg_weight * bg

    # Final peak normalisation of the blended signal.
    return mixture / (_peak(mixture) + eps)
|
|