import os
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme

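# HF_TOKEN must be set in the environment; it authenticates the hosted Llama 3 chat model.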
token = os.getenv("HF_TOKEN")
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
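# Two DistilHuBERT checkpoints: a cry-reason classifier (mixed data) and a cry/no-cry detector.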
model_class, id2label_class = predict_params(model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data", dataset_path="A-POR-LOS-8000/data/mixed_data", filter_white_noise=True)
model_mon, id2label_mon = predict_params(model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector", dataset_path="A-POR-LOS-8000/data/baby_cry_detection", filter_white_noise=False)

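# Preprocess one audio file with the dataset's pipeline and return the model's raw logits.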
def call(audiopath, model, dataset_path, filter_white_noise):
    model.to(device)
    model.eval()
    audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise)
    processed_audio = audio_dataset.preprocess_audio(audiopath)
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    return logits

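# Classify why the baby is crying, mapping class ids to human-readable Spanish labels.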
def predict(audio_path_pred):
    with torch.no_grad():
        logits = call(audio_path_pred, model=model_class, dataset_path="A-POR-LOS-8000/data/mixed_data", filter_white_noise=True)
        predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
        label_class = id2label_class[predicted_class_ids_class]
        label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}  # display labels overriding the raw id2label names
        label_class = label_mapping.get(predicted_class_ids_class, label_class)
    return label_class

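# Monitor a streamed clip: decide whether the baby is crying and, if so, delegate to predict().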
def predict_stream(audio_path_stream):
    with torch.no_grad():
        logits = call(audio_path_stream, model=model_mon, dataset_path="A-POR-LOS-8000/data/baby_cry_detection", filter_white_noise=False)
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        crying_probabilities = probabilities[:, 1]
        # Class index 1 is treated as "no cry", so a low average probability means crying.
        avg_crying_probability = crying_probabilities.mean() * 100
        if avg_crying_probability < 15:
            label_class = predict(audio_path_stream)
            return "Está llorando por:", f"{label_class}. Probabilidad: {avg_crying_probability:.1f}%"
        else:
            return "No está llorando.", f"Probabilidad: {avg_crying_probability:.1f}%"

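# Approximate a dB level from the RMS of the monitor model's logits.
# Note: this is a pseudo-loudness derived from the logits, not from the raw waveform.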
def decibelios(audio_path_stream):
    with torch.no_grad():
        logits = call(audio_path_stream, model=model_mon, dataset_path="A-POR-LOS-8000/data/baby_cry_detection", filter_white_noise=False)
        rms = torch.sqrt(torch.mean(torch.square(logits)))
        db_level = 20 * torch.log10(rms + 1e-6).item()
        return db_level

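# Status line for the monitor: keep predicting while the pseudo-dB level is below the threshold.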
def mostrar_decibelios(audio_path_stream, visual_threshold):
    db_level = decibelios(audio_path_stream)
    if db_level < visual_threshold:
        return f"Prediciendo. Decibelios: {db_level:.2f}"
    else:
        return "No detectamos ruido..."

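# Run the crying prediction only when the pseudo-dB level is below the user-set threshold.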
def predict_stream_decib(audio_path_stream, visual_threshold):
    db_level = decibelios(audio_path_stream)
    if db_level < visual_threshold:
        llorando, probabilidad = predict_stream(audio_path_stream)
        return f"{llorando} {probabilidad}"
    else:
        return ""

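# Stream a reply from the Llama 3 chatbot, replaying the prior conversation as context.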
def chatbot_config(message, history: list[tuple[str, str]]):
    system_message = "You are a Chatbot specialized in baby health and care."
    max_tokens = 512
    temperature = 0.7
    top_p = 0.95
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
    response = ""
    for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        delta = message_response.choices[0].delta.content  # renamed so it does not shadow the global `token`
        if delta:  # the final streamed chunk can carry empty content
            response += delta
            yield response

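# Swap pages: hide the currently visible column and show the requested one.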
def cambiar_pestaña():
    return gr.update(visible=False), gr.update(visible=True)

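# UI: a landing page (chatbot + navigation) plus two hidden pages, predictor and monitor.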
with gr.Blocks(theme=my_theme) as demo:
    estilo()
    with gr.Column(visible=True) as chatbot:
        gr.Markdown("<h2>Asistente</h2>")
        gr.ChatInterface(
            chatbot_config  # TODO: review the ChatInterface arguments
            )
        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2>Predictor</h2>")
                boton_predictor = gr.Button("Prueba el predictor")
                gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
            with gr.Column():
                gr.Markdown("<h2>Monitor</h2>")
                boton_monitor = gr.Button("Prueba el monitor")
                gr.Markdown("<p>Monitoriza si tu hijo está llorando y por qué, sin levantarte del sofá</p>")
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
            )
        gr.Button("¿Por qué llora?").click(
            predict,
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:")
            )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
    with gr.Column(visible=False) as pag_monitor:
        gr.Markdown("<h2>Monitor</h2>")
        audio_stream = gr.Audio(
            format="wav",
            label="Baby recorder",
            type="filepath",
            streaming=True,
            )
        threshold_db = gr.Slider(
            minimum=0,
            maximum=100,
            step=1,
            value=30,
            label="Umbral de dB para activar la predicción"
            )
        audio_stream.stream(
            mostrar_decibelios,
            inputs=[audio_stream, threshold_db],
            outputs=gr.Textbox(value="Esperando...", label="Estado")
            )
        audio_stream.stream(
            predict_stream_decib,
            inputs=[audio_stream, threshold_db],
            outputs=gr.Textbox(value="", label="Tu bebé:"),
            )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
    boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
    boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
demo.launch(share=True)