zhzluke96
update
32b2aaa
raw
history blame
1.19 kB
from pathlib import Path
from typing import Callable
from torch import Tensor
def walk_paths(root, suffix):
    """Recursively yield every file under ``root`` whose suffix equals ``suffix``.

    ``suffix`` is compared exactly (case-sensitively) against ``Path.suffix``,
    so it should include the leading dot, e.g. ``".wav"``. Entries are visited
    in ``Path.iterdir()`` order, and each sub-directory is descended into as
    soon as it is encountered. ``Path.is_dir()`` follows symlinks, so symlinked
    directories are traversed as well.

    Args:
        root: Directory to search (``str`` or ``Path``).
        suffix: File extension to match, including the dot.

    Yields:
        ``Path`` objects for each matching file.
    """
    for entry in Path(root).iterdir():
        if not entry.is_dir():
            if entry.suffix == suffix:
                yield entry
            continue
        # Recurse into the sub-directory before moving to the next sibling.
        yield from walk_paths(entry, suffix)
def rglob_audio_files(path: Path):
    """Return all ``.wav`` and ``.flac`` files found recursively under ``path``.

    The result is a new list holding every ``.wav`` match first, followed by
    every ``.flac`` match. The directory tree is walked once per extension via
    ``walk_paths``.
    """
    matches = []
    for extension in (".wav", ".flac"):
        matches.extend(walk_paths(path, extension))
    return matches
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
"""
Args:
fg: (b, t)
bg: (b, t)
"""
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
fg = fg / (fg_energy + eps).sqrt()
bg = bg / (bg_energy + eps).sqrt()
if callable(alpha):
alpha = alpha()
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
mx = alpha * fg + (1 - alpha) * bg
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
return mx