jhtonyKoo commited on
Commit
6557f75
1 Parent(s): 940f782
Files changed (2) hide show
  1. modules/filter.py +161 -0
  2. requirements.txt +2 -2
modules/filter.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import warnings
4
+
5
+ # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/functional/functional.py#L495
6
+
7
+
8
+ def _create_triangular_filterbank(
9
+ all_freqs: torch.Tensor,
10
+ f_pts: torch.Tensor,
11
+ ) -> torch.Tensor:
12
+ """Create a triangular filter bank.
13
+
14
+ Args:
15
+ all_freqs (Tensor): STFT freq points of size (`n_freqs`).
16
+ f_pts (Tensor): Filter mid points of size (`n_filter`).
17
+
18
+ Returns:
19
+ fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
20
+ """
21
+ # Adopted from Librosa
22
+ # calculate the difference between each filter mid point and each stft freq point in hertz
23
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
24
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
25
+ # create overlapping triangles
26
+ zero = torch.zeros(1)
27
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
28
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
29
+ fb = torch.max(zero, torch.min(down_slopes, up_slopes))
30
+
31
+ return fb
32
+
33
+
34
+ # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/prototype/functional/functional.py#L6
35
+
36
+
37
+ def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
38
+ r"""Convert Hz to Barks.
39
+
40
+ Args:
41
+ freqs (float): Frequencies in Hz
42
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
43
+
44
+ Returns:
45
+ barks (float): Frequency in Barks
46
+ """
47
+
48
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
49
+ raise ValueError(
50
+ 'bark_scale should be one of "schroeder", "traunmuller" or "wang".'
51
+ )
52
+
53
+ if bark_scale == "wang":
54
+ return 6.0 * math.asinh(freqs / 600.0)
55
+ elif bark_scale == "schroeder":
56
+ return 7.0 * math.asinh(freqs / 650.0)
57
+ # Traunmuller Bark scale
58
+ barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
59
+ # Bark value correction
60
+ if barks < 2:
61
+ barks += 0.15 * (2 - barks)
62
+ elif barks > 20.1:
63
+ barks += 0.22 * (barks - 20.1)
64
+
65
+ return barks
66
+
67
+
68
+ def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
69
+ """Convert bark bin numbers to frequencies.
70
+
71
+ Args:
72
+ barks (torch.Tensor): Bark frequencies
73
+ bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
74
+
75
+ Returns:
76
+ freqs (torch.Tensor): Barks converted in Hz
77
+ """
78
+
79
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
80
+ raise ValueError(
81
+ 'bark_scale should be one of "traunmuller", "schroeder" or "wang".'
82
+ )
83
+
84
+ if bark_scale == "wang":
85
+ return 600.0 * torch.sinh(barks / 6.0)
86
+ elif bark_scale == "schroeder":
87
+ return 650.0 * torch.sinh(barks / 7.0)
88
+ # Bark value correction
89
+ if any(barks < 2):
90
+ idx = barks < 2
91
+ barks[idx] = (barks[idx] - 0.3) / 0.85
92
+ elif any(barks > 20.1):
93
+ idx = barks > 20.1
94
+ barks[idx] = (barks[idx] + 4.422) / 1.22
95
+
96
+ # Traunmuller Bark scale
97
+ freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
98
+
99
+ return freqs
100
+
101
+
102
+ def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
103
+ a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
104
+ return torch.log2(freqs / (a440 / 16))
105
+
106
+
107
+ def barkscale_fbanks(
108
+ n_freqs: int,
109
+ f_min: float,
110
+ f_max: float,
111
+ n_barks: int,
112
+ sample_rate: int,
113
+ bark_scale: str = "traunmuller",
114
+ ) -> torch.Tensor:
115
+ r"""Create a frequency bin conversion matrix.
116
+
117
+ .. devices:: CPU
118
+
119
+ .. properties:: TorchScript
120
+
121
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
122
+ :alt: Visualization of generated filter bank
123
+
124
+ Args:
125
+ n_freqs (int): Number of frequencies to highlight/apply
126
+ f_min (float): Minimum frequency (Hz)
127
+ f_max (float): Maximum frequency (Hz)
128
+ n_barks (int): Number of mel filterbanks
129
+ sample_rate (int): Sample rate of the audio waveform
130
+ bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
131
+
132
+ Returns:
133
+ torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``)
134
+ meaning number of frequencies to highlight/apply to x the number of filterbanks.
135
+ Each column is a filterbank so that assuming there is a matrix A of
136
+ size (..., ``n_freqs``), the applied result would be
137
+ ``A * barkscale_fbanks(A.size(-1), ...)``.
138
+
139
+ """
140
+
141
+ # freq bins
142
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
143
+
144
+ # calculate bark freq bins
145
+ m_min = _hz_to_bark(f_min, bark_scale=bark_scale)
146
+ m_max = _hz_to_bark(f_max, bark_scale=bark_scale)
147
+
148
+ m_pts = torch.linspace(m_min, m_max, n_barks + 2)
149
+ f_pts = _bark_to_hz(m_pts, bark_scale=bark_scale)
150
+
151
+ # create filterbank
152
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
153
+
154
+ if (fb.max(dim=0).values == 0.0).any():
155
+ warnings.warn(
156
+ "At least one bark filterbank has all zero values. "
157
+ f"The value for `n_barks` ({n_barks}) may be set too high. "
158
+ f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
159
+ )
160
+
161
+ return fb
requirements.txt CHANGED
@@ -6,7 +6,7 @@ pytube==15.0.0
6
  librosa==0.10.2
7
  scipy==1.11.3
8
  numba==0.58.1
9
- soxbindings==1.2.3
10
  auraloss==0.4.0
11
  dasp-pytorch==0.0.1
12
- torchcomp==0.1.3
 
 
6
  librosa==0.10.2
7
  scipy==1.11.3
8
  numba==0.58.1
 
9
  auraloss==0.4.0
10
  dasp-pytorch==0.0.1
11
+ torchcomp==0.1.3
12
+ pytorch-lightning==2.4.0