# Imports
import csv
import sys
import numpy as np
import soundfile
import tensorflow as tf
from python.util.audio_util import audio_to_wav
from python.util.plt_util import plt_line, plt_mfcc, plt_mfcc2
from python.util.time_util import int_to_min_sec
from python.util.str_util import format_float, truncate_str
from python.util.tensorflow_util import predict
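# Note: the python.util.* modules above are project-local helpers, not PyPI packages.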
# Constants
# MODEL_PATH = 'res/lite-model_yamnet_tflite_1.tflite'
MODEL_PATH = 'res/lite-model_yamnet_classification_tflite_1.tflite'
OUT_SAMPLE_RATE = 16000
OUT_PCM = 'PCM_16'
CLASS_MAP_FILE = 'res/yamnet_class_map.csv'
DEBUG = True
# SNORING_TOP_N = 21
SNORING_INDEX = 38
IN_MODEL_SAMPLES = 15600
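# The YAMNet TFLite classification model expects a fixed-length mono waveform of
# 0.975 s at 16 kHz, i.e. 0.975 * 16000 = 15600 samples per inference.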
# Methods
def to_ndarray(data):
    return np.array(data)
def data_to_single_channel(data):
    # Multi-channel audio is shaped (samples, channels); keep only the first channel.
    # Mono audio is 1-D, so indexing a second axis would raise IndexError.
    if data.ndim > 1:
        return data[:, 0]
    return data
def read_single_channel(audio_path):
    data, sample_rate = soundfile.read(audio_path)
    print(' sample_rate, audio_path: ', str(sample_rate), str(audio_path))
    # print(' sample_rate, len, type, shape, shape[1]: ', str(sample_rate), len(data), str(type(data)), str(data.shape), str(data.shape[1]))
    single_channel = data_to_single_channel(data)
    single_channel_seconds = len(single_channel) / OUT_SAMPLE_RATE
    # print(' single_channel, shape: ', str(single_channel), str(single_channel.shape))
    # print(' len, seconds: ', str(len(single_channel)), str(single_channel_seconds))
    return single_channel, sample_rate
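# yamnet_class_map.csv has a header row plus one row per class in the form
# (index, mid, display_name), which is why the reader below skips the first
# line and unpacks only the third column.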
def class_names_from_csv(class_map_csv):
    """Read the class name definition file and return a list of strings."""
    if tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    with open(class_map_csv) as csv_file:
        reader = csv.reader(csv_file)
        next(reader)  # Skip header
        return np.array([display_name for (_, _, display_name) in reader])
def scores_to_index(scores, order):
    means = scores.mean(axis=0)
    return np.argsort(means)[order]
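# predict() is assumed to return per-frame class scores shaped
# (n_frames, n_classes); averaging over axis 0 yields one mean score per class.
# np.argsort is ascending, so order=-1 picks the best class, -2 the runner-up, etc.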
def predict_waveform(idx, waveform, top_n):
    # Download the YAMNet class map (see main YAMNet model docs) to yamnet_class_map.csv
    # See YAMNet TF2 usage sample for class_names_from_csv() definition.
    scores = predict(MODEL_PATH, waveform)
    class_names = class_names_from_csv(CLASS_MAP_FILE)
    # top_n = SNORING_TOP_N
    means = scores.mean(axis=0)
    top_n_res = ''
    snoring_score = 0.0
    for n in range(1, top_n):  # the top_n - 1 best classes, best first
        index = scores_to_index(scores, -n)
        score = means[index]
        name = class_names[index]
        if index == SNORING_INDEX:
            snoring_score = score
        top_n_res += ' ' + format_float(score) + ' [' + truncate_str(name, 4) + '], '
    snoring_tail = ('snoring, ' + format_float(snoring_score)) if snoring_score > 0 else ''
    result = top_n_res + snoring_tail + '\n'
    if DEBUG: print(top_n_res)
    return result, snoring_score
def to_float32(data):
    return np.float32(data)
def predict_float32(idx, data, top_n):
    return predict_waveform(idx, to_float32(data), top_n)
def split_given_size(arr, size):
    # Split arr into chunks of `size` samples; the last chunk may be shorter.
    return np.split(arr, np.arange(size, len(arr), size))
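# e.g. split_given_size(np.arange(7), 3) -> [array([0, 1, 2]), array([3, 4, 5]), array([6])]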
def predict_uri(audio_uri1, audio_uri2, top_n):
    result = ''
    if DEBUG: print('audio_uri1:', audio_uri1, 'audio_uri2:', audio_uri2)
    mp3_input = audio_uri1 if audio_uri2 in (None, '') else audio_uri2
    wav_input = audio_to_wav(mp3_input) if mp3_input.endswith('.mp3') else mp3_input
    predict_seconds = int(sys.argv[2]) if len(sys.argv) > 2 else 1
    predict_samples = IN_MODEL_SAMPLES  # OUT_SAMPLE_RATE * predict_seconds
    single_channel, sc_sample_rate = read_single_channel(wav_input)
    splits = split_given_size(single_channel, predict_samples)
    result += ' sc_sample_rate: ' + str(sc_sample_rate) + '\n'
    second_total = len(splits) * predict_seconds
    result += (' second_total: ' + int_to_min_sec(second_total) + ', \n')
    result += '\n'
    snoring_scores = []
    for idx in range(len(splits)):
        split = splits[idx]
        second_start = idx * predict_seconds
        result += (int_to_min_sec(second_start) + ', ')
        if len(split) == predict_samples:  # skip the final partial chunk
            print_result, snoring_score = predict_float32(idx, split, top_n)
            result += print_result
            snoring_scores.append(snoring_score)
    # plt waveform
    waveform_line = plt_line(single_channel)
    # plt mfcc
    mfcc_line = plt_mfcc(single_channel, OUT_SAMPLE_RATE)
    # plt mfcc2
    mfcc2_line = plt_mfcc2(wav_input, OUT_SAMPLE_RATE)
    # plt snoring_booleans
    snoring_booleans = [1 if x > 0 else 0 for x in snoring_scores]
    # calc snoring frequency
    snoring_sec = sum(snoring_booleans)
    snoring_frequency = snoring_sec / second_total
    apnea_sec = second_total - snoring_sec
    apnea_frequency = (apnea_sec / 10) / second_total
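    # Note: dividing apnea_sec by 10 presumably treats every 10 non-snoring
    # seconds as one apnea-like event (clinically, an apnea is a breathing
    # pause of >= 10 s); this heuristic is an assumption, not part of YAMNet.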
    ahi_result = str(
        'snoring_sec=' + str(snoring_sec) + ', apnea_sec=' + str(apnea_sec) + ', second_total=' + str(second_total)
        + ', snoring_frequency=' + str(snoring_sec) + '/' + str(second_total) + '=' + format_float(snoring_frequency)
        + ', apnea_frequency=(' + str(apnea_sec) + '/' + str(10) + ')/' + str(second_total) + '=' + format_float(apnea_frequency)
    )
    return waveform_line, mfcc_line, mfcc2_line, str(ahi_result), str(snoring_booleans), str(snoring_scores), str(result)
# sys.argv
if len(sys.argv) > 1 and len(sys.argv[1]) > 0:
    waveform_line, mfcc_line, mfcc2_line, ahi_result, snoring_booleans, snoring_scores, result = predict_uri(sys.argv[1], None, 21)
    print(result)
    print(ahi_result)
else:
    print('usage: python test.py /path/to/audio_file [predict_seconds]')
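# Example invocation (the audio path is illustrative):
#   python test.py res/sample_night.mp3 1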