import streamlit as st
import librosa
import librosa.display
import torch
import onnx
import onnxruntime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

from config import CONFIG
from dataset import MaskGenerator
|
|
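# Load the exported FRN ONNX checkpoint once and cache it across Streamlit reruns.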
@st.cache_resource
def load_model():
    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names

|
|
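# Run the model frame by frame over the lossy spectrogram, feeding each step's output
# states back in as the next step's inputs, then invert the STFT to a waveform.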
def inference(re_im, session, onnx_model, input_names, output_names):
    # Zero-initialize all model inputs from the static shapes declared in the ONNX graph;
    # the first input is overwritten with each incoming frame, the rest carry the
    # previous magnitude and the recurrent states between steps.
    inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
                                       dtype=np.float32)
              for i, _input in enumerate(onnx_model.graph.input)}

    output_audio = []
    for t in range(re_im.shape[0]):
        inputs[input_names[0]] = re_im[t]
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        output_audio.append(out)

    # Reassemble the concealed frames and invert the STFT back to a waveform.
    # `window`, `stride`, and `hann` are module-level values defined below.
    output_audio = torch.tensor(np.concatenate(output_audio, 0))
    output_audio = output_audio.permute(1, 0, 2).contiguous()
    output_audio = torch.view_as_complex(output_audio)
    output_audio = torch.istft(output_audio, window, stride, window=hann)
    return output_audio.numpy()

|
|
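# Plot linear-frequency spectrograms of the target, lossy, and enhanced signals for comparison.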
def visualize(hr, lr, recon):
    sr = CONFIG.DATA.sr
    window_size = 1024
    window = np.hanning(window_size)

    stft_hr = librosa.stft(hr, n_fft=window_size, hop_length=512, window=window)
    stft_hr = 2 * np.abs(stft_hr) / np.sum(window)

    stft_lr = librosa.stft(lr, n_fft=window_size, hop_length=512, window=window)
    stft_lr = 2 * np.abs(stft_lr) / np.sum(window)

    stft_recon = librosa.stft(recon, n_fft=window_size, hop_length=512, window=window)
    stft_recon = 2 * np.abs(stft_recon) / np.sum(window)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
    ax1.title.set_text('Target signal')
    ax2.title.set_text('Lossy signal')
    ax3.title.set_text('Enhanced signal')

    FigureCanvas(fig)
    librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='linear', x_axis='time', sr=sr)
    librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='linear', x_axis='time', sr=sr)
    librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='linear', x_axis='time', sr=sr)
    return fig

|
|
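# Packetization and STFT settings from the config; `window` and `stride` are also used by inference() above.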
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size
stride = CONFIG.DATA.stride

title = 'Packet Loss Concealment'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

uploaded_file = st.file_uploader("Upload your audio file (.wav)")

# Fall back to a bundled sample when no file is uploaded.
if uploaded_file is None:
    uploaded_file = 'sample.wav'

# Load at 48 kHz and trim to a whole number of packets.
target, sr = librosa.load(uploaded_file, sr=48000)
target = target[:packet_size * (len(target) // packet_size)]

st.subheader('Original audio')
st.audio(uploaded_file)

st.subheader('Choose packet loss percentage')
loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
loss_percent = float(loss_percent[:-1]) / 100

# Simulate packet loss by zeroing out whole packets with a mask from MaskGenerator.
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

# Square-root Hann window shared by the analysis STFT here and the synthesis iSTFT in inference().
hann = torch.sqrt(torch.hann_window(window))
lossy_input_tensor = torch.tensor(lossy_input)
re_im = torch.stft(lossy_input_tensor, window, stride, window=hann,
                   return_complex=False).permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)
session, onnx_model, input_names, output_names = load_model()

|
|
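# On demand, conceal the lost packets and show spectrograms plus audio players for all three signals.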
if st.button('Conceal lossy audio!'):
    with st.spinner('Please wait for completion'):
        output = inference(re_im, session, onnx_model, input_names, output_names)

    st.subheader('Visualization')
    fig = visualize(target, lossy_input, output)
    st.pyplot(fig)
    st.success('Done!')

    st.text('Original audio')
    st.audio(target, sample_rate=sr)
    st.text('Lossy audio')
    st.audio(lossy_input, sample_rate=sr)
    st.text('Enhanced audio')
    st.audio(output, sample_rate=sr)