Marcos12886 committed on
Commit
763091b
1 Parent(s): 1565b0a

EVERYTHING WORKING. Same as GitHub

Files changed (3)
  1. README.md +31 -13
  2. app.py +107 -32
  3. model.py +19 -28
README.md CHANGED
@@ -1,13 +1,31 @@
----
-title: CHATBOT
-emoji: 🔥
-colorFrom: pink
-colorTo: pink
-sdk: gradio
-sdk_version: 4.42.0
-app_file: app.py
-pinned: false
-license: apache-2.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## Installation
+Installation and usage assume a reasonably capable NVIDIA GPU. If you do not have an NVIDIA GPU, run on the GPUs Colab provides.
+
+Required installs for local use:
+- pip install transformers[torch] gradio tensorboardX scikit-learn
+- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+
+#### GitHub
+The .gitignore file lists the folders that should not be pushed to GitHub.
+
+## Structure
+Three features:
+- Baby monitor: detect whether your baby is crying and why
+- Cry classifier: find out why your baby is crying
+- Chatbot: talk to Llama 3 8B about your concerns regarding your baby
+
+File flow:
+1. Build the model architectures and train them: [model.py](model.py)
+2. Chatbot that records audio and connects to the LLM: [app.py](app.py)
+
+One model ([model.py](model.py)) trained on different data:
+- Monitoring model: --n monitor
+- Cry-classifier model: --n class
+
+Chatbot: [app.py](app.py)
+
+### Datasets used
+- https://data.mendeley.com/datasets/hbppd883sd/1
+- https://zenodo.org/records/2535878
+- https://paperswithcode.com/dataset/esc50
+- https://osf.io/usr8d
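
The `--n` switch in the README above selects which data the single model architecture in model.py is trained on. model.py's argument handling is not part of this diff, so the following is only a sketch of how such a flag could dispatch the two dataset paths that app.py uses; everything except the paths is illustrative:

```python
# Hypothetical sketch: dispatch the training data from the --n flag.
# The dataset paths match the ones app.py passes to predict_params;
# the parser itself is assumed, not taken from model.py.
import argparse

DATASETS = {
    "monitor": "data/baby_cry_detection",  # cry-vs-no-cry monitor data
    "class": "data/mixed_data",            # cry-cause classifier data
}

parser = argparse.ArgumentParser(description="Train one model on different data")
parser.add_argument("--n", choices=DATASETS, required=True,
                    help="variant to train: monitor or class")
args = parser.parse_args()
print(f"Training the '{args.n}' variant on {DATASETS[args.n]}")
```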
app.py CHANGED
@@ -3,25 +3,24 @@ import torch
 import gradio as gr
 from huggingface_hub import InferenceClient
 from model import predict_params, AudioDataset
-from interfaz import estilo, my_theme
-
+# TODO: stop it from reporting that 1s_normal is missing when predicting
 token = os.getenv("HF_TOKEN")
 client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_class, id2label_class = predict_params(
-    model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data",
+    model_path="distilhubert-finetuned-mixed-data",
     dataset_path="data/mixed_data",
     filter_white_noise=True,
     undersample_normal=True
 )
 model_mon, id2label_mon = predict_params(
-    model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector",
+    model_path="distilhubert-finetuned-cry-detector",
     dataset_path="data/baby_cry_detection",
     filter_white_noise=False,
     undersample_normal=False
 )
 
-def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal):
+def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
     model.to(device)
     model.eval()
     audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal)
@@ -34,10 +33,10 @@ def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal)
 
 def predict(audio_path_pred):
     with torch.no_grad():
-        logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=True)
+        logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False)
         predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
         label_class = id2label_class[predicted_class_ids_class]
-        label_mapping = {0: 'Dolor', 1: 'Cansancio/Incomodidad', 2: 'Hambre', 3: 'Problemas para respirar'}
+        label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'}
         label_class = label_mapping.get(predicted_class_ids_class, label_class)
     return label_class
 
@@ -49,9 +48,9 @@ def predict_stream(audio_path_stream):
     avg_crying_probability = crying_probabilities.mean()*100
     if avg_crying_probability < 15:
         label_class = predict(audio_path_stream)
-        return "Está llorando por:", f"{label_class}. Probabilidad: {avg_crying_probability:.1f}%"
+        return f"Está llorando por: {label_class}"
     else:
-        return "No está llorando.", f"Probabilidad: {avg_crying_probability:.1f}%"
+        return "No está llorando."
 
 def decibelios(audio_path_stream):
     with torch.no_grad():
@@ -70,15 +69,15 @@ def mostrar_decibelios(audio_path_stream, visual_threshold):
 def predict_stream_decib(audio_path_stream, visual_threshold):
     db_level = decibelios(audio_path_stream)
     if db_level < visual_threshold:
-        llorando, probabilidad = predict_stream(audio_path_stream)
-        return f"{llorando} {probabilidad}"
+        llorando = predict_stream(audio_path_stream)
+        return f"{llorando}"
     else:
         return ""
 
 def chatbot_config(message, history: list[tuple[str, str]]):
     system_message = "You are a Chatbot specialized in baby health and care."
     max_tokens = 512
-    temperature = 0.7
+    temperature = 0.5
     top_p = 0.95
     messages = [{"role": "system", "content": system_message}]
     for val in history:
@@ -96,25 +95,100 @@ def chatbot_config(message, history: list[tuple[str, str]]):
 def cambiar_pestaña():
     return gr.update(visible=False), gr.update(visible=True)
 
+my_theme = gr.themes.Soft(
+    primary_hue="emerald",
+    secondary_hue="green",
+    neutral_hue="slate",
+    text_size="sm",
+    spacing_size="sm",
+    font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
+    font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
+).set(
+    body_background_fill='*neutral_50',
+    body_text_color='*neutral_600',
+    body_text_size='*text_sm',
+    embed_radius='*radius_md',
+    shadow_drop='*shadow_spread',
+    shadow_spread='*button_shadow_active'
+)
+
 with gr.Blocks(theme=my_theme) as demo:
-    estilo()
-    with gr.Column(visible=True) as chatbot:
-        gr.Markdown("<h2>Asistente</h2>")
-        gr.ChatInterface(
-            chatbot_config # TODO: review arguments
-        )
-        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("<h2>Predictor</h2>")
-                boton_predictor = gr.Button("Prueba el predictor")
-                gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
-            with gr.Column():
-                gr.Markdown("<h2>Monitor</h2>")
-                boton_monitor = gr.Button("Prueba el monitor")
-                gr.Markdown("<p>Monitoriza si tu hijo está llorando y por qué, sin levantarte del sofá</p>")
+    with gr.Column(visible=True) as inicial:
+        gr.HTML(
+            """
+            <style>
+            @import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
+            @import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
+
+            h1 {
+                font-family: 'Lobster', cursive;
+                font-size: 5em !important;
+                text-align: center;
+                margin: 0;
+            }
+
+            .gr-button {
+                background-color: #4CAF50 !important;
+                color: white !important;
+                border: none;
+                padding: 25px 50px; /* increase the padding for bigger buttons */
+                text-align: center;
+                text-decoration: none;
+                display: inline-block;
+                font-family: 'Lobster', cursive; /* apply the Lobster font */
+                font-size: 2em !important; /* increase the button text size */
+                margin: 4px 2px;
+                cursor: pointer;
+                border-radius: 12px;
+            }
+
+            .gr-button:hover {
+                background-color: #45a049;
+            }
+            h2 {
+                font-family: 'Lobster', cursive;
+                font-size: 3em !important;
+                text-align: center;
+                margin: 0;
+            }
+            p.slogan, h4, p, h3 {
+                font-family: 'Roboto', sans-serif;
+                text-align: center;
+            }
+            </style>
+            <h1>Iremia</h1>
+            <h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
+            """
+        )
+        gr.Markdown(
+            "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
+            "<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
+            "<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
+            "<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
+            "<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
+            "<p style='text-align: left'>Chatbot: Pregunta a nuestro asistente que te ayudará con cualquier duda que tengas sobre el cuidado de tu bebé.</p>"
+            "<p style='text-align: left'>Analizador: Con nuestro modelo de inteligencia artificial somos capaces de predecir por qué tu hijo de menos de 2 años está llorando.</p>"
+            "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
+        )
+        boton_inicial = gr.Button("Comenzar")
+    with gr.Column(visible=False) as chatbot:
+        gr.Markdown("<h2>Asistente</h2>")
+        gr.ChatInterface(
+            chatbot_config,
+            theme=my_theme,
+            retry_btn=None,
+            undo_btn=None,
+            clear_btn="Limpiar 🗑️",
+            autofocus=True,
+            fill_height=True,
+        )
+        with gr.Row():
+            with gr.Column():
+                boton_predictor = gr.Button("Analizador")
+            with gr.Column():
+                boton_monitor = gr.Button("Monitor")
     with gr.Column(visible=False) as pag_predictor:
-        gr.Markdown("<h2>Predictor</h2>")
+        gr.Markdown("<h2>Analizador</h2>")
         audio_input = gr.Audio(
             min_length=1.0,
             format="wav",
@@ -126,7 +200,7 @@ with gr.Blocks(theme=my_theme) as demo:
             inputs=audio_input,
             outputs=gr.Textbox(label="Tu bebé llora por:")
         )
-        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
+        gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
     with gr.Column(visible=False) as pag_monitor:
         gr.Markdown("<h2>Monitor</h2>")
         audio_stream = gr.Audio(
@@ -140,7 +214,7 @@ with gr.Blocks(theme=my_theme) as demo:
             maximum=100,
             step=1,
             value=30,
-            label="Umbral de dB para activar la predicción"
+            label="Decibelios para activar la predicción:"
         )
         audio_stream.stream(
             mostrar_decibelios,
@@ -152,7 +226,8 @@ with gr.Blocks(theme=my_theme) as demo:
             inputs=[audio_stream, threshold_db],
             outputs=gr.Textbox(value="", label="Tu bebé:")
         )
-        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
+        gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
+    boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot])
     boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
     boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
 demo.launch(share=True)
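
A note on the monitor flow in this app.py diff: predict_stream_decib only runs the cry models while decibelios reports a level below the slider threshold, and predict_stream then treats an average crying probability under 15% as crying. The body of decibelios is outside the hunks shown, so the sketch below assumes a simple RMS-based estimate; gate_and_predict and run_cry_pipeline are stand-ins for predict_stream_decib and the model chain, not the app's real functions:

```python
# Minimal sketch of the monitor's gating logic; compute_db is an assumed
# RMS-based decibel estimate (decibelios' real body is not in this diff).
import torch

def compute_db(waveform: torch.Tensor) -> float:
    rms = torch.sqrt(torch.mean(waveform ** 2))
    return (20 * torch.log10(rms + 1e-9)).item()  # epsilon avoids log(0)

def run_cry_pipeline(waveform: torch.Tensor) -> str:
    return "No está llorando."  # stand-in for the model_mon -> model_class chain

def gate_and_predict(waveform: torch.Tensor, threshold_db: float = 30.0) -> str:
    db_level = compute_db(waveform)
    if db_level < threshold_db:            # below threshold: run the predictor
        return run_cry_pipeline(waveform)
    return ""                              # above threshold: show nothing

print(gate_and_predict(0.001 * torch.randn(16000)))
```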
model.py CHANGED
@@ -30,7 +30,7 @@ class AudioDataset(Dataset):
         self.dataset_path = dataset_path
         self.label2id = label2id
         self.file_paths = []
-        self.filter_white_noise = filter_white_noise # Changed this line
+        self.filter_white_noise = filter_white_noise
         self.labels = []
         for label_dir, label_id in self.label2id.items():
             label_path = os.path.join(self.dataset_path, label_dir)
@@ -39,33 +39,25 @@ class AudioDataset(Dataset):
                 audio_path = os.path.join(label_path, file_name)
                 self.file_paths.append(audio_path)
                 self.labels.append(label_id)
-        if undersample_normal:
+        if undersample_normal and self.label2id:
             self.undersample_normal_class()
 
     def undersample_normal_class(self):
         normal_label = self.label2id.get('1s_normal')
-        if normal_label is None:
-            print("Warning: No '1s_normal' class found. Skipping undersampling.")
-            return
         label_counts = Counter(self.labels)
         other_counts = [count for label, count in label_counts.items() if label != normal_label]
-        if not other_counts:
-            print("Warning: No non-normal classes found. Skipping undersampling.")
-            return
-        target_count = max(other_counts)
-        normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
-        if len(normal_indices) <= target_count:
-            print("Warning: Normal class count is already <= other class counts. Skipping undersampling.")
-            return
-        keep_indices = random.sample(normal_indices, target_count)
-        new_file_paths = []
-        new_labels = []
-        for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
-            if label != normal_label or i in keep_indices:
-                new_file_paths.append(path)
-                new_labels.append(label)
-        self.file_paths = new_file_paths
-        self.labels = new_labels
+        if other_counts: # ensure there are other counts before taking max
+            target_count = max(other_counts)
+            normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
+            keep_indices = random.sample(normal_indices, target_count)
+            new_file_paths = []
+            new_labels = []
+            for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
+                if label != normal_label or i in keep_indices:
+                    new_file_paths.append(path)
+                    new_labels.append(label)
+            self.file_paths = new_file_paths
+            self.labels = new_labels
 
     def __len__(self):
         return len(self.file_paths)
@@ -107,12 +99,11 @@ def is_white_noise(audio):
     std = torch.std(audio)
     return torch.abs(mean) < 0.001 and std < 0.01
 
-def seed_everything():
+def seed_everything(): # TODO: check whether anything else is needed
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16384:8'
+    # torch.backends.cudnn.deterministic = True # for reproducibility
+    # torch.backends.cudnn.benchmark = False # for reproducibility
 
 def build_label_mappings(dataset_path):
     label2id = {}
@@ -165,10 +156,10 @@ def load_model(model_path, id2label, num_labels):
         finetuning_task="audio-classification"
     )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = HubertForSequenceClassification.from_pretrained( # TODO: review parameters; possible optimizations
+    model = HubertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path,
        config=config,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32, # TODO: check whether float32 is needed or it can be swapped for float16
    )
    model.to(device)
    return model
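
The refactored undersample_normal_class above trims the '1s_normal' class down to the size of the largest remaining class, now guarded by a single `if other_counts:` instead of early returns. A standalone toy run of the same balancing logic (the label names other than '1s_normal' are invented for the example):

```python
# Toy run of the undersampling in this commit: keep only as many
# '1s_normal' samples as the largest non-normal class has.
import random
from collections import Counter

random.seed(0)
label2id = {"1s_normal": 0, "1s_hambre": 1, "1s_dolor": 2}  # example mapping
labels = [0] * 10 + [1] * 3 + [2] * 4
file_paths = [f"clip_{i}.wav" for i in range(len(labels))]

normal_label = label2id.get("1s_normal")
other_counts = [c for lbl, c in Counter(labels).items() if lbl != normal_label]
if other_counts:  # same guard the diff adds before taking max()
    target_count = max(other_counts)  # 4, the largest non-normal class
    normal_indices = [i for i, lbl in enumerate(labels) if lbl == normal_label]
    keep = set(random.sample(normal_indices, target_count))
    kept = [(p, lbl) for i, (p, lbl) in enumerate(zip(file_paths, labels))
            if lbl != normal_label or i in keep]
    print(Counter(lbl for _, lbl in kept))  # -> Counter({0: 4, 2: 4, 1: 3})
```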