Spaces:
Running
Running
File size: 17,930 Bytes
1e6dc54 40b8b4f 017e65e 40b8b4f 53f6532 40b8b4f 53f6532 40b8b4f 53f6532 ace06e3 763091b 40b8b4f 1e6dc54 166aa6c 40b8b4f 0f202d9 40b8b4f 166aa6c 40b8b4f 166aa6c 40b8b4f 166aa6c 40b8b4f 166aa6c 40b8b4f 017e65e 40b8b4f 0f202d9 40b8b4f 166aa6c 40b8b4f abdf62b 40b8b4f 0f202d9 40b8b4f abdf62b 40b8b4f ace06e3 40b8b4f ebf42ac 40b8b4f abdf62b 40b8b4f ace06e3 1e6dc54 40b8b4f 1e6dc54 017e65e 40b8b4f 017e65e 40b8b4f 017e65e 40b8b4f 017e65e 763091b 2fc4b9d 763091b deb9c39 763091b deb9c39 763091b deb9c39 763091b 2fc4b9d 763091b deb9c39 40b8b4f 2fc4b9d 40b8b4f 2fc4b9d 40b8b4f 2fc4b9d 40b8b4f abdf62b 40b8b4f 166aa6c 40b8b4f abdf62b 40b8b4f 017e65e 166aa6c 40b8b4f 017e65e 40b8b4f 166aa6c 40b8b4f 017e65e abdf62b 40b8b4f abdf62b 40b8b4f 2fc4b9d 40b8b4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 |
import os
import torch
import torch.nn.functional as F # Importa la API funcional de torch, incluyendo softmax
import gradio as gr # Gradio para crear interfaces web
from huggingface_hub import InferenceClient # Cliente de inferencia para acceder a modelos desde Hugging Face Hub
from model import predict_params, AudioDataset # Importaciones personalizadas: carga de modelo y procesamiento de audio
import torchaudio # Librería para procesamiento de audio
token = os.getenv("HF_TOKEN") # Obtiene el token de la API de Hugging Face desde las variables de entorno
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token) # Inicializa el cliente de Hugging Face con el modelo y el token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Verifica si hay GPU disponible, de lo contrario usa CPU
model_class, id2label_class = predict_params(
model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data2", # Ruta al modelo para la predicción de clases de llanto
dataset_path="data/mixed_data", # Ruta al dataset de audio mixto
filter_white_noise=True, # Indica que se filtrará el ruido blanco
undersample_normal=True # Activa el submuestreo para equilibrar clases
)
model_mon, id2label_mon = predict_params(
model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector", # Ruta al modelo detector de llanto
dataset_path="data/baby_cry_detection", # Ruta al dataset de detección de llanto
filter_white_noise=False, # No filtrar ruido blanco en este modelo
undersample_normal=False # No submuestrear datos
)
def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
model.to(device) # Envía el modelo a la GPU (o CPU si no hay GPU disponible)
model.eval() # Pone el modelo en modo de evaluación (desactiva dropout, batchnorm)
audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal) # Carga el dataset de audio con parámetros específicos
processed_audio = audio_dataset.preprocess_audio(audiopath) # Preprocesa el audio según la configuración del dataset
inputs = {"input_values": processed_audio.to(device).unsqueeze(0)} # Prepara los datos para el modelo (envía a GPU y ajusta dimensiones)
with torch.no_grad(): # Desactiva el cálculo del gradiente para ahorrar memoria
outputs = model(**inputs) # Realiza la inferencia con el modelo
logits = outputs.logits # Obtiene las predicciones del modelo
return logits # Retorna los logits (valores sin procesar)
def predict(audio_path_pred):
with torch.no_grad(): # Desactiva gradientes para la inferencia
logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False) # Llama a la función de inferencia
predicted_class_ids_class = torch.argmax(logits, dim=-1).item() # Obtiene la clase predicha a partir de los logits
label_class = id2label_class[predicted_class_ids_class] # Convierte el ID de clase en una etiqueta de texto
label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'} # Mapea las etiquetas
label_class = label_mapping.get(predicted_class_ids_class, label_class) # Si hay una etiqueta personalizada, la usa
return f"""
<div style='text-align: center; font-size: 1.5em'>
<span style='display: inline-block; min-width: 300px;'>{label_class}</span>
</div>
""" # Retorna el resultado formateado para mostrar en la interfaz
def predict_stream(audio_path_stream):
with torch.no_grad(): # Desactiva gradientes durante la inferencia
logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False) # Llama al modelo de detección de llanto
probabilities = F.softmax(logits, dim=-1) # Aplica softmax para convertir los logits en probabilidades
crying_probabilities = probabilities[:, 1] # Obtiene las probabilidades asociadas al llanto
avg_crying_probability = crying_probabilities.mean()*100 # Calcula la probabilidad promedio de llanto
if avg_crying_probability < 15: # Si la probabilidad de llanto es menor a un 15%, se predice la razón
label_class = predict(audio_path_stream) # Llama a la predicción para determinar la razón del llanto
return f"Está llorando por: {label_class}" # Retorna el resultado indicando por qué llora
else:
return "No está llorando" # Si la probabilidad es mayor, indica que no está llorando
def decibelios(audio_path_stream):
waveform, _ = torchaudio.load(audio_path_stream) # Carga el audio y su forma de onda
rms = torch.sqrt(torch.mean(torch.square(waveform))) # Calcula el valor RMS del audio
db_level = 20 * torch.log10(rms + 1e-6).item() # Convierte el RMS en decibelios (añade un pequeño valor para evitar log(0))
min_db = -80 # Nivel mínimo de decibelios esperado
max_db = 0 # Nivel máximo de decibelios esperado
scaled_db_level = (db_level - min_db) / (max_db - min_db) # Escala el nivel de decibelios a un rango entre 0 y 1
normalized_db_level = scaled_db_level * 100 # Escala el nivel de decibelios a un porcentaje
return normalized_db_level # Retorna el nivel de decibelios normalizado
def mostrar_decibelios(audio_path_stream, visual_threshold):
db_level = decibelios(audio_path_stream)# Obtiene el nivel de decibelios del audio
if db_level > visual_threshold: # Si el nivel de decibelios supera el umbral visual
status = "Prediciendo..." # Cambia el estado a "Prediciendo"
else:
status = "Esperando..." # Si no supera el umbral, indica que está "Esperando"
return f"""
<div style='text-align: center; font-size: 1.5em'>
<span>{status}</span>
<span style='display: inline-block; min-width: 120px;'>Decibelios: {db_level:.2f}</span>
</div>
""" # Retorna una cadena HTML con el estado y el nivel de decibelios
def predict_stream_decib(audio_path_stream, visual_threshold):
db_level = decibelios(audio_path_stream) # Calcula el nivel de decibelios
if db_level > visual_threshold: # Si supera el umbral, hace una predicción
prediction = display_prediction_stream(audio_path_stream) # Llama a la función de predicción
else:
prediction = "" # Si no supera el umbral, no muestra predicción
return f"""
<div style='text-align: center; font-size: 1.5em; min-height: 2em;'>
<span style='display: inline-block; min-width: 300px;'>{prediction}</span>
</div>
""" # Retorna el resultado o nada si no supera el umbral
def chatbot_config(message, history: list[tuple[str, str]]):
system_message = "You are a Chatbot specialized in baby health and care." # Mensaje inicial del chatbot
max_tokens = 512 # Máximo de tokens para la respuesta
temperature = 0.5 # Controla la aleatoriedad de las respuestas
top_p = 0.95 # Top-p sampling para filtrar palabras
messages = [{"role": "system", "content": system_message}] # Configura el mensaje del sistema para el chatbot
for val in history: # Añade el historial de la conversación al mensaje
if val[0]:
messages.append({"role": "user", "content": val[0]}) # Añade los mensajes del usuario
if val[1]:
messages.append({"role": "assistant", "content": val[1]}) # Añade las respuestas del asistente
messages.append({"role": "user", "content": message}) # Añade el mensaje actual del usuario
response = "" # Inicializa la variable de respuesta
for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
token = message_response.choices[0].delta.content # Obtiene el contenido del mensaje generado por el modelo
response += token # Acumula el contenido generado en la respuesta final
return response # Retorna la respuesta generada por el modelo
def cambiar_pestaña():
return gr.update(visible=False), gr.update(visible=True) # Esta función cambia la visibilidad de las pestañas en la interfaz
def display_prediction(audio, prediction_func):
prediction = prediction_func(audio) # Llama a la función de predicción para obtener el resultado
return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>" # Retorna el resultado formateado en HTML
def display_prediction_wrapper(audio):
return display_prediction(audio, predict) # Envuelve la función de predicción "predict" en la función "display_prediction"
def display_prediction_stream(audio):
return display_prediction(audio, predict_stream) # Envuelve la función de predicción "predict_stream" en la función "display_prediction"
my_theme = gr.themes.Soft(
primary_hue="emerald",
secondary_hue="green",
neutral_hue="slate",
text_size="sm",
spacing_size="sm",
font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
).set(
body_background_fill='*neutral_50',
body_text_color='*neutral_600',
body_text_size='*text_sm',
embed_radius='*radius_md',
shadow_drop='*shadow_spread',
shadow_spread='*button_shadow_active'
)
with gr.Blocks(theme=my_theme, fill_height=True, fill_width=True) as demo:
with gr.Column(visible=True) as inicial:
gr.HTML(
"""
<style>
@import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
h1 {
font-family: 'Lobster', cursive;
font-size: 5em !important;
text-align: center;
margin: 0;
}
.gr-button {
background-color: #4CAF50 !important;
color: white !important;
border: none;
padding: 25px 50px;
text-align: center;
text-decoration: none;
display: inline-block;
font-family: 'Lobster', cursive;
font-size: 2em !important;
margin: 4px 2px;
cursor: pointer;
border-radius: 12px;
}
.gr-button:hover {
background-color: #45a049;
}
h2 {
font-family: 'Lobster', cursive;
font-size: 3em !important;
text-align: center;
margin: 0;
}
p.slogan, h4, p, h3 {
font-family: 'Roboto', sans-serif;
text-align: center;
}
</style>
<h1>Iremia</h1>
<h4 style='text-align: center; font-size: 1.5em'>El mejor aliado para el bienestar de tu bebé</h4>
"""
)
gr.Markdown(
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
"<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
"<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
"<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
"<p style='text-align: left'>Chatbot: Pregunta a nuestro asistente que te ayudará con cualquier duda que tengas sobre el cuidado de tu bebé.</p>"
"<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial somos capaces de predecir por qué tu bebé está llorando.</p>"
"<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
)
boton_inicial = gr.Button("¡Prueba nuestros modelos!")
with gr.Column(visible=False) as chatbot: # Columna para la pestaña del chatbot
gr.Markdown("<h2>Asistente</h2>") # Título de la pestaña del chatbot
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre el cuidado de tu bebé</h4>") # Descripción de la pestaña del chatbot
gr.ChatInterface(
chatbot_config, # Función de configuración del chatbot
theme=my_theme, # Tema personalizado para la interfaz
retry_btn=None, # Botón de reintentar desactivado
undo_btn=None, # Botón de deshacer desactivado
clear_btn="Limpiar 🗑️", # Botón de limpiar mensajes
submit_btn="Enviar", # Botón de enviar mensaje
autofocus=True, # Enfocar automáticamente el campo de entrada de texto
fill_height=True, # Rellenar el espacio verticalmente
)
with gr.Row(): # Fila para los botones de cambio de pestaña
with gr.Column(): # Columna para el botón del predictor
gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del chatbot
boton_predictor = gr.Button("Probar predictor") # Botón para cambiar a la pestaña del predictor
with gr.Column(): # Columna para el botón del monitor
gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del chatbot
boton_monitor = gr.Button("Probar monitor") # Botón para cambiar a la pestaña del monitor
boton_volver_inicio = gr.Button("Volver al inicio") # Botón para volver a la pestaña inicial
with gr.Column(visible=False) as pag_predictor: # Columna para la pestaña del predictor
gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del predictor
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>") # Descripción de la pestaña del predictor
audio_input = gr.Audio(
min_length=1.0, # Duración mínima del audio requerida
format="wav", # Formato de audio admitido
label="Baby recorder", # Etiqueta del campo de entrada de audio
type="filepath", # Tipo de entrada de audio (archivo)
)
prediction_output = gr.Markdown() # Salida para mostrar la predicción
gr.Button("¿Por qué llora?").click(
display_prediction_wrapper, # Función de predicción para el botón
inputs=audio_input, # Entrada de audio para la función de predicción
outputs=gr.Markdown() # Salida para mostrar la predicción
)
gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot]) # Botón para volver a la pestaña del chatbot
with gr.Column(visible=False) as pag_monitor: # Columna para la pestaña del monitor
gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del monitor
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>") # Descripción de la pestaña del monitor
audio_stream = gr.Audio(
format="wav", # Formato de audio admitido
label="Baby recorder", # Etiqueta del campo de entrada de audio
type="filepath", # Tipo de entrada de audio (archivo)
streaming=True # Habilitar la transmisión de audio en tiempo real
)
threshold_db = gr.Slider(
minimum=0, # Valor mínimo del umbral de ruido
maximum=120, # Valor máximo del umbral de ruido
step=1, # Paso del umbral de ruido
value=20, # Valor inicial del umbral de ruido
label="Umbral de ruido para activar la predicción:" # Etiqueta del control deslizante del umbral de ruido
)
volver = gr.Button("Volver") # Botón para volver a la pestaña del chatbot
audio_stream.stream(
mostrar_decibelios, # Función para mostrar el nivel de decibelios
inputs=[audio_stream, threshold_db], # Entradas para la función de mostrar decibelios
outputs=gr.HTML() # Salida para mostrar el nivel de decibelios
)
audio_stream.stream(
predict_stream_decib, # Función para realizar la predicción en tiempo real
inputs=[audio_stream, threshold_db], # Entradas para la función de predicción en tiempo real
outputs=gr.HTML() # Salida para mostrar la predicción en tiempo real
)
volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot]) # Botón para volver a la pestaña del chatbot
boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot]) # Botón para cambiar a la pestaña inicial
boton_volver_inicio.click(cambiar_pestaña, outputs=[chatbot, inicial]) # Botón para volver a la pestaña inicial desde el chatbot
boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor]) # Botón para cambiar a la pestaña del predictor
boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor]) # Botón para cambiar a la pestaña del monitor
demo.launch(share=True) # Lanzar la interfaz gráfica
|