Testing / app.py
Robertomarting's picture
Update app.py
d88f91c verified
import torchaudio
import gradio as gr
import soundfile as sf
import tempfile
import os
import io
import librosa
import numpy as np
import pandas as pd
from transformers import ASTFeatureExtractor, AutoModelForAudioClassification, Trainer, Wav2Vec2FeatureExtractor, HubertForSequenceClassification, pipeline
from datasets import Dataset, DatasetDict
import torch.nn.functional as F
import torch
from collections import Counter
from scipy.stats import kurtosis
from huggingface_hub import InferenceClient
import os
import time
'''
Predictor
'''
#Obtenemos el token para traernos el modelo:
access_token_mod_1 = os.getenv('HF_Access_Personal')
#Cargamos procesador y modelo:
processor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = AutoModelForAudioClassification.from_pretrained("Robertomarting/tmp_trainer",token=access_token_mod_1)
#Definimos una función para eliminar segmentos de audio con un determinado porcentaje de ruido blanco:
def is_white_noise(audio, threshold=0.75):
kurt = kurtosis(audio)
return np.abs(kurt) < 0.1 and np.mean(np.abs(audio)) < threshold
#Función de procesado de audio, permite particionar en fragmentos de 1 segundo, hacer un trim, volverlo mono si está en estéreo, resamplearlo
#al sampling rate que admite el modelo, etc.
def process_audio(audio_tuple, target_sr=16000, target_duration=1.0):
data = []
target_length = int(target_sr * target_duration)
wav_buffer = io.BytesIO()
sf.write(wav_buffer, audio_tuple[1], audio_tuple[0], format='wav')
wav_buffer.seek(0)
audio_data, sample_rate = sf.read(wav_buffer)
audio_data = audio_data.astype(np.float32)
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
if sample_rate != target_sr:
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr)
audio_data, _ = librosa.effects.trim(audio_data)
if len(audio_data) > target_length:
for i in range(0, len(audio_data), target_length):
segment = audio_data[i:i + target_length]
if len(segment) == target_length and not is_white_noise(segment):
data.append(segment)
else:
if not is_white_noise(audio_data):
data.append(audio_data)
return data
#Se aplica al extractor de características del modelo:
def preprocess_audio(audio_segments):
inputs = processor(
audio_segments,
padding=True,
sampling_rate=processor.sampling_rate,
max_length=int(processor.sampling_rate * 1),
truncation=True,
return_tensors="pt"
)
return inputs
#Se hace la predicción para cada audio:
def predict_audio(audio):
audio_segments = process_audio(audio)
inputs = preprocess_audio(audio_segments)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1).numpy()
predicted_classes = probabilities.argmax(axis=1)
most_common_predicted_label = Counter(predicted_classes).most_common(1)[0][0]
replace_dict = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
most_common_predicted_label = replace_dict[most_common_predicted_label]
return most_common_predicted_label
def display_prediction(audio):
prediction = predict_audio(audio)
return f"<h3 style='text-align: center; font-size: 1.5em;'>Tu bebé llora por: {prediction}</h3>"
def clear_audio_input(audio):
return ""
'''
Monitor
'''
def process_audio_monitor(audio_tuple, target_sr=16000, target_duration=1.0):
data = []
target_length = int(target_sr * target_duration)
wav_buffer = io.BytesIO()
sf.write(wav_buffer, audio_tuple[1], audio_tuple[0], format='wav')
wav_buffer.seek(0)
audio_data, sample_rate = sf.read(wav_buffer)
audio_data = audio_data.astype(np.float32)
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
if sample_rate != target_sr:
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sr)
audio_data, _ = librosa.effects.trim(audio_data)
if len(audio_data) > target_length:
for i in range(0, len(audio_data), target_length):
segment = audio_data[i:i + target_length]
if len(segment) == target_length:
data.append(segment)
else:
data.append(audio_data)
return data
#Sacamos extractor de características:
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained("ntu-spml/distilhubert")
#Y nuestro modelo:
model_monitor = HubertForSequenceClassification.from_pretrained("A-POR-LOS-8000/distilhubert-finetuned-cry-detector")
#Calculamos decibelios de lo que llega al gradio:
def compute_db(audio_data):
rms = np.sqrt(np.mean(np.square(audio_data)))
db = 20 * np.log10(rms + 1e-6)
db = round(db,2)
return db
#Función de extracción de características para el monitor:
def preprocess_audio_monitor(audio_segments):
inputs = FEATURE_EXTRACTOR(
audio_segments,
padding=True,
sampling_rate=16000,
max_length=int(16000*1),
return_tensors="pt"
)
return inputs
#Función de predicción en streaming:
def predict_audio_stream(audio_data, sample_rate):
audio_segments = process_audio_monitor(audio_data)
inputs = preprocess_audio_monitor(audio_segments)
with torch.no_grad():
outputs = model_monitor(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1).numpy()
crying_probabilities = probabilities[:, 1]
avg_crying_probability = crying_probabilities.mean()
if avg_crying_probability < 0.25:
inputs = preprocess_audio(audio_segments)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1).numpy()
predicted_classes = probabilities.argmax(axis=1)
most_common_predicted_label = Counter(predicted_classes).most_common(1)[0][0]
replace_dict = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
most_common_predicted_label = replace_dict[most_common_predicted_label]
return "Está llorando", 1-avg_crying_probability, most_common_predicted_label
else:
return "No está llorando", 1-avg_crying_probability, ""
#Función que se encarga de indicarle al usuario si se ha pasado el umbral:
def update_status_to_predicting(audio, visual_threshold):
sample_rate, audio_data = audio
audio_data = np.array(audio_data, dtype=np.float32)
db_level = compute_db(audio_data)
db_level = round(db_level, 2)
if db_level < visual_threshold:
return f"Esperando... Decibelios: {db_level}"
else:
return f"Prediciendo... Decibelios: {db_level}"
#Función que realiza la predicción
def capture_and_predict(audio,visual_threshold, sample_rate=16000, duration=5):
sample_rate, audio_data = audio
audio_data = np.array(audio_data, dtype=np.float32)
db_level = compute_db(audio_data)
if db_level > visual_threshold:
max_samples = sample_rate * duration
audio_data = audio[:max_samples]
if len(audio_data) != 0:
result, probabilidad, result_2 = predict_audio_stream(audio_data, sample_rate)
if result == "Está llorando":
return f"{result}, por {result_2}"
else:
return "No está llorando"
'''
Asistente
'''
#Traemos el token:
access_token = os.getenv('HF_ACCESS_TOKEN')
#Generamos el cliente:
client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407", token=access_token)
#Generamos una función de respuesta:
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
response = ""
for message in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
token = message.choices[0].delta.content
response += token
yield response
'''
Interfaz
'''
#Generamos un theme con parámetros personalizados:
my_theme = gr.themes.Soft(
primary_hue="emerald",
secondary_hue="green",
neutral_hue="slate",
text_size="sm",
spacing_size="sm",
font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
).set(
body_background_fill='*neutral_50',
body_text_color='*neutral_600',
body_text_size='*text_sm',
embed_radius='*radius_md',
shadow_drop='*shadow_spread',
shadow_spread='*button_shadow_active'
)
#Función para mostrar la página del Predictor
def mostrar_pagina_1():
return gr.update(visible=False), gr.update(visible=True)
#Función para regresar a la pantalla inicial
def redirigir_a_pantalla_inicial():
return gr.update(visible=True), gr.update(visible=False)
#Generamos el gradio:
with gr.Blocks(theme = my_theme) as demo:
with gr.Column() as pantalla_inicial:
gr.HTML(
"""
<style>
@import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
h1 {
font-family: 'Lobster', cursive;
font-size: 5em !important;
text-align: center;
margin: 0;
}
h2 {
font-family: 'Lobster', cursive;
font-size: 3em !important;
text-align: center;
margin: 0;
}
h3 {
font-family: 'Roboto', sans-serif;
text-align: center;
font-size: 1.5em !important;
}
p.slogan, h4, p {
font-family: 'Roboto', sans-serif;
text-align: center;
}
</style>
<h1>Iremia</h1>
<h4 style='text-align: center; font-size: 1.5em'>El mejor aliado para el bienestar de tu bebé</h4>
"""
)
gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>")
gr.Markdown("<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>")
gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>")
gr.Markdown("<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>")
gr.Markdown("<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>")
gr.Markdown("<p style='text-align: left'>Iremia ofrece dos funcionalidades muy interesantes:</p>")
gr.Markdown("<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial, somos capaces de predecir por qué tu hijo de menos de 2 años está llorando. Además, tendrás acceso a un asistente personal para consultar cualquier duda que tengas sobre el cuidado de tu pequeño.</p>")
gr.Markdown("<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no, y si está llorando, predice automáticamente la causa, lo cual te brindará la tranquilidad de saber siempre qué pasa con tu pequeño y te ahorrará tiempo y muchas horas de sueño.</p>")
gr.Markdown("<p style='text-align: left'>Asistente: Contamos con un chatbot que podrá responder cualquier duda que tengas sobre el cuidado de tu bebé.</p>")
with gr.Row():
with gr.Column():
boton_pagina_1 = gr.Button("¡Prueba nuestros modelos!")
gr.Markdown("<p>Descubre por qué llora tu bebé, prueba nuestro monitor inteligente y resuelve dudas sobre el cuidado de tu pequeño con nuestras herramientas de última tecnología</p>")
with gr.Column(visible=False) as pagina_1:
with gr.Row():
with gr.Column():
gr.Markdown("<h2>Predictor</h2>")
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>")
audio_input = gr.Audio(type="numpy", label="Baby recorder")
classify_btn = gr.Button("¿Por qué llora?")
classification_output = gr.Markdown()
classify_btn.click(display_prediction, inputs=audio_input, outputs=classification_output)
audio_input.change(fn=clear_audio_input, inputs=audio_input, outputs=classification_output)
with gr.Column():
gr.Markdown("<h2>Monitor</h2>")
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando</h4>")
audio_stream = gr.Audio(sources=["microphone"], streaming=True)
threshold_db = gr.Slider(minimum=0, maximum=200, step=1, value=50, label="Umbral de dB para activar la predicción")
status_label = gr.Textbox(label="Estado")
prediction_label = gr.Textbox(label="Tu bebé:")
audio_stream.stream(
fn=update_status_to_predicting,
inputs=[audio_stream, threshold_db],
outputs=status_label
)
# Captura el audio y realiza la predicción si se supera el umbral
audio_stream.stream(
fn=capture_and_predict,
inputs=[audio_stream,threshold_db],
outputs=prediction_label
)
with gr.Row():
with gr.Column():
gr.Markdown("<h2>Asistente</h2>")
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre tu pequeño</h4>")
system_message = "Eres un chatbot especializado en el cuidado y la salud de los bebés. Estás dispuesto a ayudar amablemente a cualquier padre que tenga dudas o preocupaciones sobre su hijo o hija."
max_tokens = 512
temperature = 0.7
top_p = 0.95
chatbot = gr.ChatInterface(
respond,
additional_inputs=[
gr.State(value=system_message),
gr.State(value=max_tokens),
gr.State(value=temperature),
gr.State(value=top_p)
],
)
gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
boton_volver_inicio_1 = gr.Button("Volver a la pantalla inicial")
boton_volver_inicio_1.click(redirigir_a_pantalla_inicial, inputs=None, outputs=[pantalla_inicial, pagina_1])
boton_pagina_1.click(mostrar_pagina_1, inputs=None, outputs=[pantalla_inicial, pagina_1])
demo.launch()