File size: 17,930 Bytes
1e6dc54
 
40b8b4f
 
 
 
 
017e65e
40b8b4f
 
 
53f6532
40b8b4f
 
 
 
53f6532
 
40b8b4f
 
 
 
53f6532
ace06e3
763091b
40b8b4f
 
 
 
 
 
 
 
 
1e6dc54
166aa6c
40b8b4f
 
 
 
 
 
0f202d9
 
 
 
40b8b4f
166aa6c
 
40b8b4f
 
 
 
 
 
 
 
166aa6c
40b8b4f
166aa6c
 
40b8b4f
 
 
 
 
 
 
 
166aa6c
 
40b8b4f
 
 
017e65e
40b8b4f
0f202d9
 
 
 
 
40b8b4f
166aa6c
 
40b8b4f
 
 
abdf62b
40b8b4f
0f202d9
 
 
 
40b8b4f
abdf62b
 
40b8b4f
 
 
 
 
 
ace06e3
40b8b4f
ebf42ac
40b8b4f
 
 
abdf62b
40b8b4f
 
 
ace06e3
1e6dc54
40b8b4f
1e6dc54
017e65e
40b8b4f
 
017e65e
 
40b8b4f
017e65e
 
40b8b4f
017e65e
763091b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fc4b9d
763091b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb9c39
763091b
 
 
deb9c39
 
763091b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb9c39
763091b
 
 
 
 
 
 
 
 
2fc4b9d
763091b
 
deb9c39
40b8b4f
 
 
 
 
 
 
 
 
2fc4b9d
40b8b4f
 
 
 
 
2fc4b9d
 
40b8b4f
2fc4b9d
 
 
40b8b4f
 
 
abdf62b
40b8b4f
 
 
 
 
 
166aa6c
40b8b4f
 
 
 
 
 
 
 
abdf62b
40b8b4f
 
 
 
017e65e
166aa6c
40b8b4f
 
 
 
 
017e65e
40b8b4f
166aa6c
40b8b4f
 
 
017e65e
abdf62b
40b8b4f
 
 
abdf62b
40b8b4f
 
2fc4b9d
40b8b4f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import os
import torch
import torch.nn.functional as F # Importa la API funcional de torch, incluyendo softmax
import gradio as gr # Gradio para crear interfaces web
from huggingface_hub import InferenceClient # Cliente de inferencia para acceder a modelos desde Hugging Face Hub
from model import predict_params, AudioDataset # Importaciones personalizadas: carga de modelo y procesamiento de audio
import torchaudio # Librería para procesamiento de audio

token = os.getenv("HF_TOKEN") # Obtiene el token de la API de Hugging Face desde las variables de entorno
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token) # Inicializa el cliente de Hugging Face con el modelo y el token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Verifica si hay GPU disponible, de lo contrario usa CPU
model_class, id2label_class = predict_params(
    model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data2", # Ruta al modelo para la predicción de clases de llanto
    dataset_path="data/mixed_data", # Ruta al dataset de audio mixto
    filter_white_noise=True, # Indica que se filtrará el ruido blanco
    undersample_normal=True # Activa el submuestreo para equilibrar clases
    )
model_mon, id2label_mon = predict_params(
    model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector", # Ruta al modelo detector de llanto
    dataset_path="data/baby_cry_detection", # Ruta al dataset de detección de llanto
    filter_white_noise=False, # No filtrar ruido blanco en este modelo
    undersample_normal=False # No submuestrear datos
    )

def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
    model.to(device) # Envía el modelo a la GPU (o CPU si no hay GPU disponible)
    model.eval() # Pone el modelo en modo de evaluación (desactiva dropout, batchnorm)
    audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal) # Carga el dataset de audio con parámetros específicos
    processed_audio = audio_dataset.preprocess_audio(audiopath) # Preprocesa el audio según la configuración del dataset
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)} # Prepara los datos para el modelo (envía a GPU y ajusta dimensiones)
    with torch.no_grad(): # Desactiva el cálculo del gradiente para ahorrar memoria
        outputs = model(**inputs) # Realiza la inferencia con el modelo
        logits = outputs.logits # Obtiene las predicciones del modelo
    return logits # Retorna los logits (valores sin procesar)

def predict(audio_path_pred):
    with torch.no_grad(): # Desactiva gradientes para la inferencia
        logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False) # Llama a la función de inferencia
        predicted_class_ids_class = torch.argmax(logits, dim=-1).item() # Obtiene la clase predicha a partir de los logits
        label_class = id2label_class[predicted_class_ids_class] # Convierte el ID de clase en una etiqueta de texto
        label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'} # Mapea las etiquetas
        label_class = label_mapping.get(predicted_class_ids_class, label_class) # Si hay una etiqueta personalizada, la usa
    return f"""
        <div style='text-align: center; font-size: 1.5em'>
            <span style='display: inline-block; min-width: 300px;'>{label_class}</span>
        </div>
    """ # Retorna el resultado formateado para mostrar en la interfaz

def predict_stream(audio_path_stream):
    with torch.no_grad(): # Desactiva gradientes durante la inferencia
        logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False) # Llama al modelo de detección de llanto
        probabilities = F.softmax(logits, dim=-1) # Aplica softmax para convertir los logits en probabilidades
        crying_probabilities = probabilities[:, 1] # Obtiene las probabilidades asociadas al llanto
        avg_crying_probability = crying_probabilities.mean()*100 # Calcula la probabilidad promedio de llanto
        if avg_crying_probability < 15: # Si la probabilidad de llanto es menor a un 15%, se predice la razón
            label_class = predict(audio_path_stream) # Llama a la predicción para determinar la razón del llanto
            return f"Está llorando por: {label_class}" # Retorna el resultado indicando por qué llora
        else:
            return "No está llorando" # Si la probabilidad es mayor, indica que no está llorando

def decibelios(audio_path_stream):
    waveform, _ = torchaudio.load(audio_path_stream) # Carga el audio y su forma de onda
    rms = torch.sqrt(torch.mean(torch.square(waveform))) # Calcula el valor RMS del audio
    db_level = 20 * torch.log10(rms + 1e-6).item() # Convierte el RMS en decibelios (añade un pequeño valor para evitar log(0))
    min_db = -80 # Nivel mínimo de decibelios esperado
    max_db = 0 # Nivel máximo de decibelios esperado
    scaled_db_level = (db_level - min_db) / (max_db - min_db) # Escala el nivel de decibelios a un rango entre 0 y 1
    normalized_db_level = scaled_db_level * 100 # Escala el nivel de decibelios a un porcentaje
    return normalized_db_level # Retorna el nivel de decibelios normalizado

def mostrar_decibelios(audio_path_stream, visual_threshold):
    db_level = decibelios(audio_path_stream)# Obtiene el nivel de decibelios del audio
    if db_level > visual_threshold: # Si el nivel de decibelios supera el umbral visual
        status = "Prediciendo..." # Cambia el estado a "Prediciendo"
    else:
        status = "Esperando..." # Si no supera el umbral, indica que está "Esperando"
    return f"""
        <div style='text-align: center; font-size: 1.5em'>
            <span>{status}</span>
            <span style='display: inline-block; min-width: 120px;'>Decibelios: {db_level:.2f}</span>
        </div>
    """ # Retorna una cadena HTML con el estado y el nivel de decibelios

def predict_stream_decib(audio_path_stream, visual_threshold):
    db_level = decibelios(audio_path_stream) # Calcula el nivel de decibelios
    if db_level > visual_threshold: # Si supera el umbral, hace una predicción
        prediction = display_prediction_stream(audio_path_stream) # Llama a la función de predicción
    else:
        prediction = "" # Si no supera el umbral, no muestra predicción
    return f"""
        <div style='text-align: center; font-size: 1.5em; min-height: 2em;'>
            <span style='display: inline-block; min-width: 300px;'>{prediction}</span>
        </div>
    """ # Retorna el resultado o nada si no supera el umbral

def chatbot_config(message, history: list[tuple[str, str]]):
    system_message = "You are a Chatbot specialized in baby health and care." # Mensaje inicial del chatbot
    max_tokens = 512 # Máximo de tokens para la respuesta
    temperature = 0.5 # Controla la aleatoriedad de las respuestas
    top_p = 0.95 # Top-p sampling para filtrar palabras
    messages = [{"role": "system", "content": system_message}] # Configura el mensaje del sistema para el chatbot
    for val in history: # Añade el historial de la conversación al mensaje
        if val[0]:
            messages.append({"role": "user", "content": val[0]}) # Añade los mensajes del usuario
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]}) # Añade las respuestas del asistente
    messages.append({"role": "user", "content": message}) # Añade el mensaje actual del usuario
    response = "" # Inicializa la variable de respuesta
    for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        token = message_response.choices[0].delta.content # Obtiene el contenido del mensaje generado por el modelo
        response += token # Acumula el contenido generado en la respuesta final
    return response # Retorna la respuesta generada por el modelo

def cambiar_pestaña():
    return gr.update(visible=False), gr.update(visible=True) # Esta función cambia la visibilidad de las pestañas en la interfaz

def display_prediction(audio, prediction_func):
    prediction = prediction_func(audio) # Llama a la función de predicción para obtener el resultado
    return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>" # Retorna el resultado formateado en HTML

def display_prediction_wrapper(audio):
    return display_prediction(audio, predict) # Envuelve la función de predicción "predict" en la función "display_prediction"

def display_prediction_stream(audio):
    return display_prediction(audio, predict_stream) # Envuelve la función de predicción "predict_stream" en la función "display_prediction"

my_theme = gr.themes.Soft(
    primary_hue="emerald",
    secondary_hue="green",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
    font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
    font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
    ).set(
    body_background_fill='*neutral_50',
    body_text_color='*neutral_600',
    body_text_size='*text_sm',
    embed_radius='*radius_md',
    shadow_drop='*shadow_spread',
    shadow_spread='*button_shadow_active'
    )

with gr.Blocks(theme=my_theme, fill_height=True, fill_width=True) as demo:
    with gr.Column(visible=True) as inicial:    
        gr.HTML(
            """
            <style>
            @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
            @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
            
            h1 {
                font-family: 'Lobster', cursive;
                font-size: 5em !important;
                text-align: center;
                margin: 0;
            }
            
            .gr-button {
                background-color: #4CAF50 !important; 
                color: white !important; 
                border: none;
                padding: 25px 50px; 
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-family: 'Lobster', cursive; 
                font-size: 2em !important; 
                margin: 4px 2px;
                cursor: pointer;
                border-radius: 12px;
            }
            
            .gr-button:hover {
                background-color: #45a049; 
            }
            h2 {
                font-family: 'Lobster', cursive;
                font-size: 3em !important;
                text-align: center;
                margin: 0;
            }
            p.slogan, h4, p, h3 {
                font-family: 'Roboto', sans-serif;
                text-align: center;
            }
            </style>
            <h1>Iremia</h1>
            <h4 style='text-align: center; font-size: 1.5em'>El mejor aliado para el bienestar de tu bebé</h4>
            """
        )
        gr.Markdown(
            "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
            "<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
            "<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
            "<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
            "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
            "<p style='text-align: left'>Chatbot: Pregunta a nuestro asistente que te ayudará con cualquier duda que tengas sobre el cuidado de tu bebé.</p>"
            "<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial somos capaces de predecir por qué tu bebé está llorando.</p>"
            "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
        )
        boton_inicial = gr.Button("¡Prueba nuestros modelos!")
    with gr.Column(visible=False) as chatbot: # Columna para la pestaña del chatbot
        gr.Markdown("<h2>Asistente</h2>") # Título de la pestaña del chatbot
        gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre el cuidado de tu bebé</h4>")  # Descripción de la pestaña del chatbot
        gr.ChatInterface(
            chatbot_config, # Función de configuración del chatbot
            theme=my_theme, # Tema personalizado para la interfaz
            retry_btn=None, # Botón de reintentar desactivado
            undo_btn=None, # Botón de deshacer desactivado
            clear_btn="Limpiar 🗑️", # Botón de limpiar mensajes
            submit_btn="Enviar", # Botón de enviar mensaje
            autofocus=True, # Enfocar automáticamente el campo de entrada de texto
            fill_height=True, # Rellenar el espacio verticalmente
        )
        with gr.Row(): # Fila para los botones de cambio de pestaña
            with gr.Column(): # Columna para el botón del predictor
                gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del chatbot
                boton_predictor = gr.Button("Probar predictor") # Botón para cambiar a la pestaña del predictor
            with gr.Column(): # Columna para el botón del monitor
                gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del chatbot
                boton_monitor = gr.Button("Probar monitor") # Botón para cambiar a la pestaña del monitor
        boton_volver_inicio = gr.Button("Volver al inicio") # Botón para volver a la pestaña inicial
    with gr.Column(visible=False) as pag_predictor: # Columna para la pestaña del predictor
        gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del predictor
        gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>") # Descripción de la pestaña del predictor
        audio_input = gr.Audio(
            min_length=1.0, # Duración mínima del audio requerida
            format="wav", # Formato de audio admitido
            label="Baby recorder", # Etiqueta del campo de entrada de audio
            type="filepath", # Tipo de entrada de audio (archivo)
        )
        prediction_output = gr.Markdown() # Salida para mostrar la predicción
        gr.Button("¿Por qué llora?").click(
            display_prediction_wrapper, # Función de predicción para el botón
            inputs=audio_input, # Entrada de audio para la función de predicción
            outputs=gr.Markdown() # Salida para mostrar la predicción
        )
        gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot]) # Botón para volver a la pestaña del chatbot
    with gr.Column(visible=False) as pag_monitor: # Columna para la pestaña del monitor
        gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del monitor
        gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>")  # Descripción de la pestaña del monitor
        audio_stream = gr.Audio(
            format="wav", # Formato de audio admitido
            label="Baby recorder", # Etiqueta del campo de entrada de audio
            type="filepath", # Tipo de entrada de audio (archivo)
            streaming=True # Habilitar la transmisión de audio en tiempo real
        )
        threshold_db = gr.Slider(
            minimum=0, # Valor mínimo del umbral de ruido
            maximum=120, # Valor máximo del umbral de ruido
            step=1, # Paso del umbral de ruido
            value=20, # Valor inicial del umbral de ruido
            label="Umbral de ruido para activar la predicción:" # Etiqueta del control deslizante del umbral de ruido
        )
        volver = gr.Button("Volver") # Botón para volver a la pestaña del chatbot
        audio_stream.stream(
            mostrar_decibelios, # Función para mostrar el nivel de decibelios
            inputs=[audio_stream, threshold_db], # Entradas para la función de mostrar decibelios
            outputs=gr.HTML() # Salida para mostrar el nivel de decibelios
        )
        audio_stream.stream(
            predict_stream_decib, # Función para realizar la predicción en tiempo real
            inputs=[audio_stream, threshold_db], # Entradas para la función de predicción en tiempo real
            outputs=gr.HTML() # Salida para mostrar la predicción en tiempo real
        )
        volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot]) # Botón para volver a la pestaña del chatbot
    boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot]) # Botón para cambiar a la pestaña inicial
    boton_volver_inicio.click(cambiar_pestaña, outputs=[chatbot, inicial]) # Botón para volver a la pestaña inicial desde el chatbot
    boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor]) # Botón para cambiar a la pestaña del predictor
    boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor]) # Botón para cambiar a la pestaña del monitor
demo.launch(share=True) # Lanzar la interfaz gráfica