CHATBOT / app.py
Marcos12886's picture
Update app.py
2fc4b9d verified
import os
import torch
import torch.nn.functional as F # Importa la API funcional de torch, incluyendo softmax
import gradio as gr # Gradio para crear interfaces web
from huggingface_hub import InferenceClient # Cliente de inferencia para acceder a modelos desde Hugging Face Hub
from model import predict_params, AudioDataset # Importaciones personalizadas: carga de modelo y procesamiento de audio
import torchaudio # Librería para procesamiento de audio
token = os.getenv("HF_TOKEN") # Obtiene el token de la API de Hugging Face desde las variables de entorno
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token) # Inicializa el cliente de Hugging Face con el modelo y el token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Verifica si hay GPU disponible, de lo contrario usa CPU
model_class, id2label_class = predict_params(
model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data2", # Ruta al modelo para la predicción de clases de llanto
dataset_path="data/mixed_data", # Ruta al dataset de audio mixto
filter_white_noise=True, # Indica que se filtrará el ruido blanco
undersample_normal=True # Activa el submuestreo para equilibrar clases
)
model_mon, id2label_mon = predict_params(
model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector", # Ruta al modelo detector de llanto
dataset_path="data/baby_cry_detection", # Ruta al dataset de detección de llanto
filter_white_noise=False, # No filtrar ruido blanco en este modelo
undersample_normal=False # No submuestrear datos
)
def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
model.to(device) # Envía el modelo a la GPU (o CPU si no hay GPU disponible)
model.eval() # Pone el modelo en modo de evaluación (desactiva dropout, batchnorm)
audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal) # Carga el dataset de audio con parámetros específicos
processed_audio = audio_dataset.preprocess_audio(audiopath) # Preprocesa el audio según la configuración del dataset
inputs = {"input_values": processed_audio.to(device).unsqueeze(0)} # Prepara los datos para el modelo (envía a GPU y ajusta dimensiones)
with torch.no_grad(): # Desactiva el cálculo del gradiente para ahorrar memoria
outputs = model(**inputs) # Realiza la inferencia con el modelo
logits = outputs.logits # Obtiene las predicciones del modelo
return logits # Retorna los logits (valores sin procesar)
def predict(audio_path_pred):
with torch.no_grad(): # Desactiva gradientes para la inferencia
logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False) # Llama a la función de inferencia
predicted_class_ids_class = torch.argmax(logits, dim=-1).item() # Obtiene la clase predicha a partir de los logits
label_class = id2label_class[predicted_class_ids_class] # Convierte el ID de clase en una etiqueta de texto
label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'} # Mapea las etiquetas
label_class = label_mapping.get(predicted_class_ids_class, label_class) # Si hay una etiqueta personalizada, la usa
return f"""
<div style='text-align: center; font-size: 1.5em'>
<span style='display: inline-block; min-width: 300px;'>{label_class}</span>
</div>
""" # Retorna el resultado formateado para mostrar en la interfaz
def predict_stream(audio_path_stream):
with torch.no_grad(): # Desactiva gradientes durante la inferencia
logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False) # Llama al modelo de detección de llanto
probabilities = F.softmax(logits, dim=-1) # Aplica softmax para convertir los logits en probabilidades
crying_probabilities = probabilities[:, 1] # Obtiene las probabilidades asociadas al llanto
avg_crying_probability = crying_probabilities.mean()*100 # Calcula la probabilidad promedio de llanto
if avg_crying_probability < 15: # Si la probabilidad de llanto es menor a un 15%, se predice la razón
label_class = predict(audio_path_stream) # Llama a la predicción para determinar la razón del llanto
return f"Está llorando por: {label_class}" # Retorna el resultado indicando por qué llora
else:
return "No está llorando" # Si la probabilidad es mayor, indica que no está llorando
def decibelios(audio_path_stream):
waveform, _ = torchaudio.load(audio_path_stream) # Carga el audio y su forma de onda
rms = torch.sqrt(torch.mean(torch.square(waveform))) # Calcula el valor RMS del audio
db_level = 20 * torch.log10(rms + 1e-6).item() # Convierte el RMS en decibelios (añade un pequeño valor para evitar log(0))
min_db = -80 # Nivel mínimo de decibelios esperado
max_db = 0 # Nivel máximo de decibelios esperado
scaled_db_level = (db_level - min_db) / (max_db - min_db) # Escala el nivel de decibelios a un rango entre 0 y 1
normalized_db_level = scaled_db_level * 100 # Escala el nivel de decibelios a un porcentaje
return normalized_db_level # Retorna el nivel de decibelios normalizado
def mostrar_decibelios(audio_path_stream, visual_threshold):
db_level = decibelios(audio_path_stream)# Obtiene el nivel de decibelios del audio
if db_level > visual_threshold: # Si el nivel de decibelios supera el umbral visual
status = "Prediciendo..." # Cambia el estado a "Prediciendo"
else:
status = "Esperando..." # Si no supera el umbral, indica que está "Esperando"
return f"""
<div style='text-align: center; font-size: 1.5em'>
<span>{status}</span>
<span style='display: inline-block; min-width: 120px;'>Decibelios: {db_level:.2f}</span>
</div>
""" # Retorna una cadena HTML con el estado y el nivel de decibelios
def predict_stream_decib(audio_path_stream, visual_threshold):
db_level = decibelios(audio_path_stream) # Calcula el nivel de decibelios
if db_level > visual_threshold: # Si supera el umbral, hace una predicción
prediction = display_prediction_stream(audio_path_stream) # Llama a la función de predicción
else:
prediction = "" # Si no supera el umbral, no muestra predicción
return f"""
<div style='text-align: center; font-size: 1.5em; min-height: 2em;'>
<span style='display: inline-block; min-width: 300px;'>{prediction}</span>
</div>
""" # Retorna el resultado o nada si no supera el umbral
def chatbot_config(message, history: list[tuple[str, str]]):
system_message = "You are a Chatbot specialized in baby health and care." # Mensaje inicial del chatbot
max_tokens = 512 # Máximo de tokens para la respuesta
temperature = 0.5 # Controla la aleatoriedad de las respuestas
top_p = 0.95 # Top-p sampling para filtrar palabras
messages = [{"role": "system", "content": system_message}] # Configura el mensaje del sistema para el chatbot
for val in history: # Añade el historial de la conversación al mensaje
if val[0]:
messages.append({"role": "user", "content": val[0]}) # Añade los mensajes del usuario
if val[1]:
messages.append({"role": "assistant", "content": val[1]}) # Añade las respuestas del asistente
messages.append({"role": "user", "content": message}) # Añade el mensaje actual del usuario
response = "" # Inicializa la variable de respuesta
for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
token = message_response.choices[0].delta.content # Obtiene el contenido del mensaje generado por el modelo
response += token # Acumula el contenido generado en la respuesta final
return response # Retorna la respuesta generada por el modelo
def cambiar_pestaña():
return gr.update(visible=False), gr.update(visible=True) # Esta función cambia la visibilidad de las pestañas en la interfaz
def display_prediction(audio, prediction_func):
prediction = prediction_func(audio) # Llama a la función de predicción para obtener el resultado
return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>" # Retorna el resultado formateado en HTML
def display_prediction_wrapper(audio):
return display_prediction(audio, predict) # Envuelve la función de predicción "predict" en la función "display_prediction"
def display_prediction_stream(audio):
return display_prediction(audio, predict_stream) # Envuelve la función de predicción "predict_stream" en la función "display_prediction"
my_theme = gr.themes.Soft(
primary_hue="emerald",
secondary_hue="green",
neutral_hue="slate",
text_size="sm",
spacing_size="sm",
font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
).set(
body_background_fill='*neutral_50',
body_text_color='*neutral_600',
body_text_size='*text_sm',
embed_radius='*radius_md',
shadow_drop='*shadow_spread',
shadow_spread='*button_shadow_active'
)
with gr.Blocks(theme=my_theme, fill_height=True, fill_width=True) as demo:
with gr.Column(visible=True) as inicial:
gr.HTML(
"""
<style>
@import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
h1 {
font-family: 'Lobster', cursive;
font-size: 5em !important;
text-align: center;
margin: 0;
}
.gr-button {
background-color: #4CAF50 !important;
color: white !important;
border: none;
padding: 25px 50px;
text-align: center;
text-decoration: none;
display: inline-block;
font-family: 'Lobster', cursive;
font-size: 2em !important;
margin: 4px 2px;
cursor: pointer;
border-radius: 12px;
}
.gr-button:hover {
background-color: #45a049;
}
h2 {
font-family: 'Lobster', cursive;
font-size: 3em !important;
text-align: center;
margin: 0;
}
p.slogan, h4, p, h3 {
font-family: 'Roboto', sans-serif;
text-align: center;
}
</style>
<h1>Iremia</h1>
<h4 style='text-align: center; font-size: 1.5em'>El mejor aliado para el bienestar de tu bebé</h4>
"""
)
gr.Markdown(
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
"<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
"<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
"<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
"<p style='text-align: left'>Chatbot: Pregunta a nuestro asistente que te ayudará con cualquier duda que tengas sobre el cuidado de tu bebé.</p>"
"<p style='text-align: left'>Predictor: Con nuestro modelo de inteligencia artificial somos capaces de predecir por qué tu bebé está llorando.</p>"
"<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
)
boton_inicial = gr.Button("¡Prueba nuestros modelos!")
with gr.Column(visible=False) as chatbot: # Columna para la pestaña del chatbot
gr.Markdown("<h2>Asistente</h2>") # Título de la pestaña del chatbot
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre el cuidado de tu bebé</h4>") # Descripción de la pestaña del chatbot
gr.ChatInterface(
chatbot_config, # Función de configuración del chatbot
theme=my_theme, # Tema personalizado para la interfaz
retry_btn=None, # Botón de reintentar desactivado
undo_btn=None, # Botón de deshacer desactivado
clear_btn="Limpiar 🗑️", # Botón de limpiar mensajes
submit_btn="Enviar", # Botón de enviar mensaje
autofocus=True, # Enfocar automáticamente el campo de entrada de texto
fill_height=True, # Rellenar el espacio verticalmente
)
with gr.Row(): # Fila para los botones de cambio de pestaña
with gr.Column(): # Columna para el botón del predictor
gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del chatbot
boton_predictor = gr.Button("Probar predictor") # Botón para cambiar a la pestaña del predictor
with gr.Column(): # Columna para el botón del monitor
gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del chatbot
boton_monitor = gr.Button("Probar monitor") # Botón para cambiar a la pestaña del monitor
boton_volver_inicio = gr.Button("Volver al inicio") # Botón para volver a la pestaña inicial
with gr.Column(visible=False) as pag_predictor: # Columna para la pestaña del predictor
gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del predictor
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>") # Descripción de la pestaña del predictor
audio_input = gr.Audio(
min_length=1.0, # Duración mínima del audio requerida
format="wav", # Formato de audio admitido
label="Baby recorder", # Etiqueta del campo de entrada de audio
type="filepath", # Tipo de entrada de audio (archivo)
)
prediction_output = gr.Markdown() # Salida para mostrar la predicción
gr.Button("¿Por qué llora?").click(
display_prediction_wrapper, # Función de predicción para el botón
inputs=audio_input, # Entrada de audio para la función de predicción
outputs=gr.Markdown() # Salida para mostrar la predicción
)
gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot]) # Botón para volver a la pestaña del chatbot
with gr.Column(visible=False) as pag_monitor: # Columna para la pestaña del monitor
gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del monitor
gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>") # Descripción de la pestaña del monitor
audio_stream = gr.Audio(
format="wav", # Formato de audio admitido
label="Baby recorder", # Etiqueta del campo de entrada de audio
type="filepath", # Tipo de entrada de audio (archivo)
streaming=True # Habilitar la transmisión de audio en tiempo real
)
threshold_db = gr.Slider(
minimum=0, # Valor mínimo del umbral de ruido
maximum=120, # Valor máximo del umbral de ruido
step=1, # Paso del umbral de ruido
value=20, # Valor inicial del umbral de ruido
label="Umbral de ruido para activar la predicción:" # Etiqueta del control deslizante del umbral de ruido
)
volver = gr.Button("Volver") # Botón para volver a la pestaña del chatbot
audio_stream.stream(
mostrar_decibelios, # Función para mostrar el nivel de decibelios
inputs=[audio_stream, threshold_db], # Entradas para la función de mostrar decibelios
outputs=gr.HTML() # Salida para mostrar el nivel de decibelios
)
audio_stream.stream(
predict_stream_decib, # Función para realizar la predicción en tiempo real
inputs=[audio_stream, threshold_db], # Entradas para la función de predicción en tiempo real
outputs=gr.HTML() # Salida para mostrar la predicción en tiempo real
)
volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot]) # Botón para volver a la pestaña del chatbot
boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot]) # Botón para cambiar a la pestaña inicial
boton_volver_inicio.click(cambiar_pestaña, outputs=[chatbot, inicial]) # Botón para volver a la pestaña inicial desde el chatbot
boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor]) # Botón para cambiar a la pestaña del predictor
boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor]) # Botón para cambiar a la pestaña del monitor
demo.launch(share=True) # Lanzar la interfaz gráfica