Upload 8 files
Browse files- PLCMOS/models/plcmos_v0.onnx +0 -0
- PLCMOS/models/plcmos_v1_intrusive.onnx +0 -0
- PLCMOS/models/plcmos_v1_nonintrusive.onnx +0 -0
- PLCMOS/plc_mos.py +154 -0
- utils/__init__.py +0 -0
- utils/stft.py +23 -0
- utils/tblogger.py +71 -0
- utils/utils.py +67 -0
PLCMOS/models/plcmos_v0.onnx
ADDED
Binary file (691 kB). View file
|
|
PLCMOS/models/plcmos_v1_intrusive.onnx
ADDED
Binary file (280 kB). View file
|
|
PLCMOS/models/plcmos_v1_nonintrusive.onnx
ADDED
Binary file (129 kB). View file
|
|
PLCMOS/plc_mos.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import onnxruntime as ort
|
7 |
+
from numpy.fft import rfft
|
8 |
+
from numpy.lib.stride_tricks import as_strided
|
9 |
+
|
10 |
+
class PLCMOSEstimator():
|
11 |
+
def __init__(self, model_version=1):
|
12 |
+
"""
|
13 |
+
Initialize a PLC-MOS model of a given version. There are currently three models available, v0 (intrusive)
|
14 |
+
and v1 (both non-intrusive and intrusive available). The default is to use the v1 models.
|
15 |
+
"""
|
16 |
+
|
17 |
+
self.model_version = model_version
|
18 |
+
model_paths = [
|
19 |
+
# v0 model:
|
20 |
+
[("models/plcmos_v0.onnx", 999999999999), (None, 0)],
|
21 |
+
|
22 |
+
# v1 models:
|
23 |
+
[("models/plcmos_v1_intrusive.onnx", 768),
|
24 |
+
("models/plcmos_v1_nonintrusive.onnx", 999999999999)],
|
25 |
+
]
|
26 |
+
self.sessions = []
|
27 |
+
self.max_lens = []
|
28 |
+
options = ort.SessionOptions()
|
29 |
+
options.intra_op_num_threads = 8
|
30 |
+
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
31 |
+
for path, max_len in model_paths[model_version]:
|
32 |
+
if not path is None:
|
33 |
+
file_dir = os.path.dirname(os.path.realpath(__file__))
|
34 |
+
self.sessions.append(ort.InferenceSession(
|
35 |
+
os.path.join(file_dir, path), options))
|
36 |
+
self.max_lens.append(max_len)
|
37 |
+
else:
|
38 |
+
self.sessions.append(None)
|
39 |
+
self.max_lens.append(0)
|
40 |
+
|
41 |
+
def logpow_dns(self, sig, floor=-30.):
|
42 |
+
"""
|
43 |
+
Compute log power of complex spectrum.
|
44 |
+
|
45 |
+
Floor any -`np.inf` value to (nonzero minimum + `floor`) dB.
|
46 |
+
If all values are 0s, floor all values to -80 dB.
|
47 |
+
"""
|
48 |
+
log10e = np.log10(np.e)
|
49 |
+
pspec = sig.real ** 2 + sig.imag ** 2
|
50 |
+
zeros = pspec == 0
|
51 |
+
logp = np.empty_like(pspec)
|
52 |
+
if np.any(~zeros):
|
53 |
+
logp[~zeros] = np.log(pspec[~zeros])
|
54 |
+
logp[zeros] = np.log(pspec[~zeros].min()) + floor / 10 / log10e
|
55 |
+
else:
|
56 |
+
logp.fill(-80 / 10 / log10e)
|
57 |
+
|
58 |
+
return logp
|
59 |
+
|
60 |
+
def hop2hsize(self, wind, hop):
|
61 |
+
"""
|
62 |
+
Convert hop fraction to integer size if necessary.
|
63 |
+
"""
|
64 |
+
if hop >= 1:
|
65 |
+
assert type(hop) == int, "Hop size must be integer!"
|
66 |
+
return hop
|
67 |
+
else:
|
68 |
+
assert 0 < hop < 1, "Hop fraction has to be in range (0,1)!"
|
69 |
+
return int(len(wind) * hop)
|
70 |
+
|
71 |
+
def stana(self, sig, sr, wind, hop, synth=False, center=False):
|
72 |
+
"""
|
73 |
+
Short term analysis by windowing
|
74 |
+
"""
|
75 |
+
ssize = len(sig)
|
76 |
+
fsize = len(wind)
|
77 |
+
hsize = self.hop2hsize(wind, hop)
|
78 |
+
if synth:
|
79 |
+
sstart = hsize - fsize # int(-fsize * (1-hfrac))
|
80 |
+
elif center:
|
81 |
+
sstart = -int(len(wind) / 2) # odd window centered at exactly n=0
|
82 |
+
else:
|
83 |
+
sstart = 0
|
84 |
+
send = ssize
|
85 |
+
|
86 |
+
nframe = math.ceil((send - sstart) / hsize)
|
87 |
+
|
88 |
+
# Calculate zero-padding sizes
|
89 |
+
zpleft = -sstart
|
90 |
+
zpright = (nframe - 1) * hsize + fsize - zpleft - ssize
|
91 |
+
if zpleft > 0 or zpright > 0:
|
92 |
+
sigpad = np.zeros(ssize + zpleft + zpright, dtype=sig.dtype)
|
93 |
+
sigpad[zpleft:len(sigpad) - zpright] = sig
|
94 |
+
else:
|
95 |
+
sigpad = sig
|
96 |
+
|
97 |
+
return as_strided(sigpad, shape=(nframe, fsize),
|
98 |
+
strides=(sig.itemsize * hsize, sig.itemsize)) * wind
|
99 |
+
|
100 |
+
def stft(self, sig, sr, wind, hop, nfft):
|
101 |
+
"""
|
102 |
+
Compute STFT: window + rfft
|
103 |
+
"""
|
104 |
+
frames = self.stana(sig, sr, wind, hop, synth=True)
|
105 |
+
return rfft(frames, n=nfft)
|
106 |
+
|
107 |
+
def stft_transform(self, audio, dft_size=512, hop_fraction=0.5, sr=16000):
|
108 |
+
"""
|
109 |
+
Compute STFT parameters, then compute STFT
|
110 |
+
"""
|
111 |
+
window = np.hamming(dft_size + 1)
|
112 |
+
window = window[:-1]
|
113 |
+
amp = np.abs(self.stft(audio, sr, window, hop_fraction, dft_size))
|
114 |
+
feat = self.logpow_dns(amp, floor=-120.)
|
115 |
+
return feat / 20.
|
116 |
+
|
117 |
+
def run(self, audio_degraded, audio_clean=None, combined=False):
|
118 |
+
"""
|
119 |
+
Run the PLCMOS model and return the MOS for the given audio. If a clean audio file is passed and the
|
120 |
+
selected model version has an intrusive version, that version will be used, otherwise, the nonintrusive
|
121 |
+
model will be used. If combined is set to true (default), the mean of intrusive and nonintrusive models
|
122 |
+
results will be returned, when both are available
|
123 |
+
|
124 |
+
For intrusive models, the clean reference should be the unprocessed audio file the degraded audio is
|
125 |
+
based on. It is not required to be aligned with the degraded audio.
|
126 |
+
|
127 |
+
Audio data should be 16kHz, mono, [-1, 1] range.
|
128 |
+
"""
|
129 |
+
audio_features_degraded = np.float32(self.stft_transform(audio_degraded))[
|
130 |
+
np.newaxis, np.newaxis, ...]
|
131 |
+
assert len(
|
132 |
+
audio_features_degraded) <= self.max_lens[0], "Maximum input length exceeded"
|
133 |
+
|
134 |
+
if audio_clean is None:
|
135 |
+
combined = False
|
136 |
+
|
137 |
+
mos = 0
|
138 |
+
|
139 |
+
session = self.sessions[0]
|
140 |
+
assert not session is None, "Intrusive model not available for this model version."
|
141 |
+
audio_features_clean = np.float32(self.stft_transform(audio_clean))[
|
142 |
+
np.newaxis, np.newaxis, ...]
|
143 |
+
assert len(
|
144 |
+
audio_features_clean) <= self.max_lens[0], "Maximum input length exceeded"
|
145 |
+
onnx_inputs = {"degraded_audio": audio_features_degraded,
|
146 |
+
"clean_audio": audio_features_clean}
|
147 |
+
mos = float(session.run(None, onnx_inputs)[0])
|
148 |
+
|
149 |
+
session = self.sessions[1]
|
150 |
+
assert not session is None, "Nonintrusive model not available for this model version."
|
151 |
+
onnx_inputs = {"degraded_audio": audio_features_degraded}
|
152 |
+
mos_2 = float(session.run(None, onnx_inputs)[0])
|
153 |
+
mos = [mos, mos_2]
|
154 |
+
return mos
|
utils/__init__.py
ADDED
File without changes
|
utils/stft.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class STFTMag(nn.Module):
|
6 |
+
def __init__(self,
|
7 |
+
nfft=1024,
|
8 |
+
hop=256):
|
9 |
+
super().__init__()
|
10 |
+
self.nfft = nfft
|
11 |
+
self.hop = hop
|
12 |
+
self.register_buffer('window', torch.hann_window(nfft), False)
|
13 |
+
|
14 |
+
# x: [B,T] or [T]
|
15 |
+
@torch.no_grad()
|
16 |
+
def forward(self, x):
|
17 |
+
stft = torch.stft(x.cpu(),
|
18 |
+
self.nfft,
|
19 |
+
self.hop,
|
20 |
+
window=self.window,
|
21 |
+
) # return_complex=False) #[B, F, TT,2]
|
22 |
+
mag = torch.norm(stft, p=2, dim=-1) # [B, F, TT]
|
23 |
+
return mag
|
utils/tblogger.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path
|
2 |
+
|
3 |
+
import librosa as rosa
|
4 |
+
import matplotlib
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
8 |
+
from pytorch_lightning.utilities import rank_zero_only
|
9 |
+
|
10 |
+
from utils.stft import STFTMag
|
11 |
+
|
12 |
+
matplotlib.use('Agg')
|
13 |
+
|
14 |
+
|
15 |
+
class TensorBoardLoggerExpanded(TensorBoardLogger):
|
16 |
+
def __init__(self, sr=16000):
|
17 |
+
super().__init__(save_dir='lightning_logs', default_hp_metric=False, name='')
|
18 |
+
self.sr = sr
|
19 |
+
self.stftmag = STFTMag()
|
20 |
+
|
21 |
+
def fig2np(self, fig):
|
22 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
23 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
24 |
+
return data
|
25 |
+
|
26 |
+
def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
|
27 |
+
name_list = ['y', 'y_low', 'y_recon']
|
28 |
+
fig = plt.figure(figsize=(9, 15))
|
29 |
+
fig.suptitle(f'Epoch_{step}')
|
30 |
+
for i, yy in enumerate([y, y_low, y_recon]):
|
31 |
+
if yy.dim() == 1:
|
32 |
+
yy = self.stftmag(yy)
|
33 |
+
ax = plt.subplot(3, 1, i + 1)
|
34 |
+
ax.set_title(name_list[i])
|
35 |
+
plt.imshow(rosa.amplitude_to_db(yy.numpy(),
|
36 |
+
ref=np.max, top_db=80.),
|
37 |
+
# vmin = -20,
|
38 |
+
vmax=0.,
|
39 |
+
aspect='auto',
|
40 |
+
origin='lower',
|
41 |
+
interpolation='none')
|
42 |
+
plt.colorbar()
|
43 |
+
plt.xlabel('Frames')
|
44 |
+
plt.ylabel('Channels')
|
45 |
+
plt.tight_layout()
|
46 |
+
|
47 |
+
fig.canvas.draw()
|
48 |
+
data = self.fig2np(fig)
|
49 |
+
|
50 |
+
plt.close()
|
51 |
+
return data
|
52 |
+
|
53 |
+
@rank_zero_only
|
54 |
+
def log_spectrogram(self, y, y_low, y_recon, epoch):
|
55 |
+
y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu()
|
56 |
+
spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
|
57 |
+
self.experiment.add_image(path.join(self.save_dir, 'result'),
|
58 |
+
spec_img,
|
59 |
+
epoch,
|
60 |
+
dataformats='HWC')
|
61 |
+
self.experiment.flush()
|
62 |
+
return
|
63 |
+
|
64 |
+
@rank_zero_only
|
65 |
+
def log_audio(self, y, y_low, y_recon, epoch):
|
66 |
+
y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu(),
|
67 |
+
name_list = ['y', 'y_low', 'y_recon']
|
68 |
+
for n, yy in zip(name_list, [y, y_low, y_recon]):
|
69 |
+
self.experiment.add_audio(n, yy, epoch, self.sr)
|
70 |
+
self.experiment.flush()
|
71 |
+
return
|
utils/utils.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import librosa.display
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
8 |
+
|
9 |
+
from config import CONFIG
|
10 |
+
|
11 |
+
|
12 |
+
def mkdir_p(mypath):
|
13 |
+
"""Creates a directory. equivalent to using mkdir -p on the command line"""
|
14 |
+
|
15 |
+
from errno import EEXIST
|
16 |
+
from os import makedirs, path
|
17 |
+
|
18 |
+
try:
|
19 |
+
makedirs(mypath)
|
20 |
+
except OSError as exc: # Python >2.5
|
21 |
+
if exc.errno == EEXIST and path.isdir(mypath):
|
22 |
+
pass
|
23 |
+
else:
|
24 |
+
raise
|
25 |
+
|
26 |
+
|
27 |
+
def visualize(target, input, recon, path):
|
28 |
+
sr = CONFIG.DATA.sr
|
29 |
+
window_size = 1024
|
30 |
+
window = np.hanning(window_size)
|
31 |
+
|
32 |
+
stft_hr = librosa.core.spectrum.stft(target, n_fft=window_size, hop_length=512, window=window)
|
33 |
+
stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
|
34 |
+
|
35 |
+
stft_lr = librosa.core.spectrum.stft(input, n_fft=window_size, hop_length=512, window=window)
|
36 |
+
stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
|
37 |
+
|
38 |
+
stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
|
39 |
+
stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
|
40 |
+
|
41 |
+
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
|
42 |
+
ax1.title.set_text('Target signal')
|
43 |
+
ax2.title.set_text('Lossy signal')
|
44 |
+
ax3.title.set_text('Reconstructed signal')
|
45 |
+
|
46 |
+
canvas = FigureCanvas(fig)
|
47 |
+
p = librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='linear', x_axis='time', sr=sr)
|
48 |
+
p = librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='linear', x_axis='time', sr=sr)
|
49 |
+
p = librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='linear', x_axis='time', sr=sr)
|
50 |
+
mkdir_p(path)
|
51 |
+
fig.savefig(os.path.join(path, 'spec.png'))
|
52 |
+
|
53 |
+
|
54 |
+
def get_power(x, nfft):
|
55 |
+
S = librosa.stft(x, n_fft=nfft)
|
56 |
+
S = np.log(np.abs(S) ** 2 + 1e-8)
|
57 |
+
return S
|
58 |
+
|
59 |
+
|
60 |
+
def LSD(x_hr, x_pr):
|
61 |
+
S1 = get_power(x_hr, nfft=2048)
|
62 |
+
S2 = get_power(x_pr, nfft=2048)
|
63 |
+
lsd = np.mean(np.sqrt(np.mean((S1 - S2) ** 2 + 1e-8, axis=-1)), axis=0)
|
64 |
+
S1 = S1[-(len(S1) - 1) // 2:, :]
|
65 |
+
S2 = S2[-(len(S2) - 1) // 2:, :]
|
66 |
+
lsd_high = np.mean(np.sqrt(np.mean((S1 - S2) ** 2 + 1e-8, axis=-1)), axis=0)
|
67 |
+
return lsd, lsd_high
|