"""Streamlit demo: simulate packet loss on an audio clip and conceal it with an ONNX model."""
import streamlit as st
import librosa
import librosa.display
from config import CONFIG
import torch
from dataset import MaskGenerator
import onnxruntime
import onnx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

# STFT / packetisation settings shared by pre-processing and inference.
packet_size = CONFIG.DATA.EVAL.packet_size
window = CONFIG.DATA.window_size  # STFT size (n_fft), in samples
stride = CONFIG.DATA.stride  # STFT hop length, in samples
# Square-root Hann window, used by both the forward STFT and the inverse STFT.
hann = torch.sqrt(torch.hann_window(window))


@st.cache_resource
def load_model(path='lightning_logs/version_0/checkpoints/frn.onnx'):
    """Load the exported ONNX model and build an onnxruntime inference session.

    Cached with ``st.cache_resource`` so Streamlit reruns reuse the same session.

    Args:
        path: Location of the exported ``.onnx`` file.

    Returns:
        ``(session, onnx_model, input_names, output_names)``. ``onnx_model`` is the
        parsed graph (used to read the declared input shapes); the name lists come
        from the runtime session, in the session's order.
    """
    onnx_model = onnx.load(path)
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 2
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = onnxruntime.InferenceSession(path, options)
    input_names = [x.name for x in session.get_inputs()]
    output_names = [x.name for x in session.get_outputs()]
    return session, onnx_model, input_names, output_names


def inference(re_im, session, onnx_model, input_names, output_names,
              n_fft=None, hop_length=None, istft_window=None):
    """Run the model frame by frame over an STFT and resynthesise the waveform.

    The model is stateful. Besides the current frame it consumes three state
    tensors, which it returns alongside each output frame; these are fed back
    in on the next step.

    Args:
        re_im: STFT frames stacked on axis 0; each ``re_im[t]`` is fed to the
            model's first input as-is. The script below builds it with shape
            ``(frames, 1, freq, 2)`` (real/imag last), float32.
        session, onnx_model, input_names, output_names: As returned by
            :func:`load_model`.
        n_fft, hop_length, istft_window: Inverse-STFT settings. They default to
            the module-level ``window``, ``stride`` and ``hann``.

    Returns:
        The reconstructed waveform as a numpy array. An empty array is returned
        when ``re_im`` has no frames.
    """
    if n_fft is None:
        n_fft = window
    if hop_length is None:
        hop_length = stride
    if istft_window is None:
        istft_window = hann
    if len(re_im) == 0:
        return np.zeros(0, dtype=np.float32)

    # Zero-initialise every graph input from its declared shape. Input 0 is
    # replaced by each frame; inputs 1-3 carry the recurrent state.
    # NOTE(review): symbolic (dynamic) dims report dim_value == 0 -- this
    # assumes the exported graph declares static state shapes; confirm.
    inputs = {
        name: np.zeros([d.dim_value for d in graph_input.type.tensor_type.shape.dim],
                       dtype=np.float32)
        for name, graph_input in zip(input_names, onnx_model.graph.input)
    }

    output_frames = []
    for frame in re_im:
        inputs[input_names[0]] = frame
        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
        # Feed the returned states back in for the next frame.
        inputs[input_names[1]] = prev_mag
        inputs[input_names[2]] = predictor_state
        inputs[input_names[3]] = mlp_state
        output_frames.append(out)

    # Stack frames on the time axis, move frequency first, and rebuild the
    # complex spectrum (assumes each output frame is (1, freq, 2) -- TODO confirm).
    spectrum = torch.tensor(np.concatenate(output_frames, 0))
    spectrum = torch.view_as_complex(spectrum.permute(1, 0, 2).contiguous())
    audio = torch.istft(spectrum, n_fft, hop_length, window=istft_window)
    return audio.numpy()


def _magnitude_spectrogram(signal, n_fft, hop_length, fft_window):
    """Return ``|STFT(signal)|`` scaled by ``2 / sum(fft_window)``."""
    spectrum = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, window=fft_window)
    return 2 * np.abs(spectrum) / np.sum(fft_window)


def visualize(hr, lr, recon):
    """Plot dB spectrograms of the target, lossy and enhanced signals.

    Args:
        hr: Target (clean) waveform.
        lr: Lossy waveform (packets zeroed out).
        recon: Waveform produced by :func:`inference`.

    Returns:
        A matplotlib ``Figure`` with three vertically stacked, axis-sharing plots.
    """
    sr = CONFIG.DATA.sr
    n_fft = 1024
    hop_length = 512
    fft_window = np.hanning(n_fft)

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive in a global registry, which leaks memory across Streamlit reruns
    # and shares state between session threads.
    fig = Figure(figsize=(16, 10))
    FigureCanvas(fig)
    axes = fig.subplots(3, 1, sharey=True, sharex=True)

    panels = (('Target signal', hr), ('Lossy signal', lr), ('Enhanced signal', recon))
    for ax, (panel_title, signal) in zip(axes, panels):
        ax.title.set_text(panel_title)
        magnitude = _magnitude_spectrogram(signal, n_fft, hop_length, fft_window)
        librosa.display.specshow(librosa.amplitude_to_db(magnitude), ax=ax, y_axis='linear',
                                 x_axis='time', sr=sr, hop_length=hop_length)
    return fig


title = 'Packet Loss Concealment'
st.set_page_config(page_title=title, page_icon=":sound:")
st.title(title)

uploaded_file = st.file_uploader("Upload your audio file (.wav)")
is_file_uploaded = uploaded_file is not None
if not is_file_uploaded:
    uploaded_file = 'sample.wav'

# Resample to the rate the model (and the visualisation) is configured for.
target, sr = librosa.load(uploaded_file, sr=CONFIG.DATA.sr)
# Drop the trailing partial packet so the signal splits into whole packets.
target = target[:packet_size * (len(target) // packet_size)]
if target.size == 0:
    st.error('The audio is shorter than one packet; please upload a longer file.')
    st.stop()

st.subheader('Original audio')
st.audio(uploaded_file)

st.subheader('Choose loss packet percentage')
loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
loss_percent = float(loss_percent[:-1]) / 100
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])

# Apply one mask value per packet (seed fixed so reruns give the same losses).
lossy_input = target.copy().reshape(-1, packet_size)
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
lossy_input *= mask
lossy_input = lossy_input.reshape(-1)

lossy_input_tensor = torch.tensor(lossy_input)
# torch.stft(..., return_complex=False) is deprecated; view_as_real yields the
# same (freq, frames, 2) real/imag layout from the complex result.
re_im = torch.view_as_real(
    torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=True))
re_im = re_im.permute(1, 0, 2).unsqueeze(1).numpy().astype(np.float32)

session, onnx_model, input_names, output_names = load_model()

if st.button('Conceal lossy audio!'):
    with st.spinner('Please wait for completion'):
        output = inference(re_im, session, onnx_model, input_names, output_names)
        st.subheader('Visualization')
        fig = visualize(target, lossy_input, output)
        st.pyplot(fig)
    st.success('Done!')
    st.text('Original audio')
    st.audio(target, sample_rate=sr)
    st.text('Lossy audio')
    st.audio(lossy_input, sample_rate=sr)
    st.text('Enhanced audio')
    st.audio(output, sample_rate=sr)