Spaces:
Running
Running
modify
Browse files- modules/filter.py +161 -0
- requirements.txt +2 -2
modules/filter.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
# https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/functional/functional.py#L495
|
6 |
+
|
7 |
+
|
8 |
+
def _create_triangular_filterbank(
|
9 |
+
all_freqs: torch.Tensor,
|
10 |
+
f_pts: torch.Tensor,
|
11 |
+
) -> torch.Tensor:
|
12 |
+
"""Create a triangular filter bank.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
all_freqs (Tensor): STFT freq points of size (`n_freqs`).
|
16 |
+
f_pts (Tensor): Filter mid points of size (`n_filter`).
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
|
20 |
+
"""
|
21 |
+
# Adopted from Librosa
|
22 |
+
# calculate the difference between each filter mid point and each stft freq point in hertz
|
23 |
+
f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
|
24 |
+
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
|
25 |
+
# create overlapping triangles
|
26 |
+
zero = torch.zeros(1)
|
27 |
+
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
|
28 |
+
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
|
29 |
+
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
|
30 |
+
|
31 |
+
return fb
|
32 |
+
|
33 |
+
|
34 |
+
# https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/prototype/functional/functional.py#L6
|
35 |
+
|
36 |
+
|
37 |
+
def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
|
38 |
+
r"""Convert Hz to Barks.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
freqs (float): Frequencies in Hz
|
42 |
+
bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
barks (float): Frequency in Barks
|
46 |
+
"""
|
47 |
+
|
48 |
+
if bark_scale not in ["schroeder", "traunmuller", "wang"]:
|
49 |
+
raise ValueError(
|
50 |
+
'bark_scale should be one of "schroeder", "traunmuller" or "wang".'
|
51 |
+
)
|
52 |
+
|
53 |
+
if bark_scale == "wang":
|
54 |
+
return 6.0 * math.asinh(freqs / 600.0)
|
55 |
+
elif bark_scale == "schroeder":
|
56 |
+
return 7.0 * math.asinh(freqs / 650.0)
|
57 |
+
# Traunmuller Bark scale
|
58 |
+
barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
|
59 |
+
# Bark value correction
|
60 |
+
if barks < 2:
|
61 |
+
barks += 0.15 * (2 - barks)
|
62 |
+
elif barks > 20.1:
|
63 |
+
barks += 0.22 * (barks - 20.1)
|
64 |
+
|
65 |
+
return barks
|
66 |
+
|
67 |
+
|
68 |
+
def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
|
69 |
+
"""Convert bark bin numbers to frequencies.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
barks (torch.Tensor): Bark frequencies
|
73 |
+
bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
freqs (torch.Tensor): Barks converted in Hz
|
77 |
+
"""
|
78 |
+
|
79 |
+
if bark_scale not in ["schroeder", "traunmuller", "wang"]:
|
80 |
+
raise ValueError(
|
81 |
+
'bark_scale should be one of "traunmuller", "schroeder" or "wang".'
|
82 |
+
)
|
83 |
+
|
84 |
+
if bark_scale == "wang":
|
85 |
+
return 600.0 * torch.sinh(barks / 6.0)
|
86 |
+
elif bark_scale == "schroeder":
|
87 |
+
return 650.0 * torch.sinh(barks / 7.0)
|
88 |
+
# Bark value correction
|
89 |
+
if any(barks < 2):
|
90 |
+
idx = barks < 2
|
91 |
+
barks[idx] = (barks[idx] - 0.3) / 0.85
|
92 |
+
elif any(barks > 20.1):
|
93 |
+
idx = barks > 20.1
|
94 |
+
barks[idx] = (barks[idx] + 4.422) / 1.22
|
95 |
+
|
96 |
+
# Traunmuller Bark scale
|
97 |
+
freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
|
98 |
+
|
99 |
+
return freqs
|
100 |
+
|
101 |
+
|
102 |
+
def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
|
103 |
+
a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
|
104 |
+
return torch.log2(freqs / (a440 / 16))
|
105 |
+
|
106 |
+
|
107 |
+
def barkscale_fbanks(
|
108 |
+
n_freqs: int,
|
109 |
+
f_min: float,
|
110 |
+
f_max: float,
|
111 |
+
n_barks: int,
|
112 |
+
sample_rate: int,
|
113 |
+
bark_scale: str = "traunmuller",
|
114 |
+
) -> torch.Tensor:
|
115 |
+
r"""Create a frequency bin conversion matrix.
|
116 |
+
|
117 |
+
.. devices:: CPU
|
118 |
+
|
119 |
+
.. properties:: TorchScript
|
120 |
+
|
121 |
+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
|
122 |
+
:alt: Visualization of generated filter bank
|
123 |
+
|
124 |
+
Args:
|
125 |
+
n_freqs (int): Number of frequencies to highlight/apply
|
126 |
+
f_min (float): Minimum frequency (Hz)
|
127 |
+
f_max (float): Maximum frequency (Hz)
|
128 |
+
n_barks (int): Number of mel filterbanks
|
129 |
+
sample_rate (int): Sample rate of the audio waveform
|
130 |
+
bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``)
|
134 |
+
meaning number of frequencies to highlight/apply to x the number of filterbanks.
|
135 |
+
Each column is a filterbank so that assuming there is a matrix A of
|
136 |
+
size (..., ``n_freqs``), the applied result would be
|
137 |
+
``A * barkscale_fbanks(A.size(-1), ...)``.
|
138 |
+
|
139 |
+
"""
|
140 |
+
|
141 |
+
# freq bins
|
142 |
+
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
|
143 |
+
|
144 |
+
# calculate bark freq bins
|
145 |
+
m_min = _hz_to_bark(f_min, bark_scale=bark_scale)
|
146 |
+
m_max = _hz_to_bark(f_max, bark_scale=bark_scale)
|
147 |
+
|
148 |
+
m_pts = torch.linspace(m_min, m_max, n_barks + 2)
|
149 |
+
f_pts = _bark_to_hz(m_pts, bark_scale=bark_scale)
|
150 |
+
|
151 |
+
# create filterbank
|
152 |
+
fb = _create_triangular_filterbank(all_freqs, f_pts)
|
153 |
+
|
154 |
+
if (fb.max(dim=0).values == 0.0).any():
|
155 |
+
warnings.warn(
|
156 |
+
"At least one bark filterbank has all zero values. "
|
157 |
+
f"The value for `n_barks` ({n_barks}) may be set too high. "
|
158 |
+
f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
|
159 |
+
)
|
160 |
+
|
161 |
+
return fb
|
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ pytube==15.0.0
|
|
6 |
librosa==0.10.2
|
7 |
scipy==1.11.3
|
8 |
numba==0.58.1
|
9 |
-
soxbindings==1.2.3
|
10 |
auraloss==0.4.0
|
11 |
dasp-pytorch==0.0.1
|
12 |
-
torchcomp==0.1.3
|
|
|
|
6 |
librosa==0.10.2
|
7 |
scipy==1.11.3
|
8 |
numba==0.58.1
|
|
|
9 |
auraloss==0.4.0
|
10 |
dasp-pytorch==0.0.1
|
11 |
+
torchcomp==0.1.3
|
12 |
+
pytorch-lightning==2.4.0
|