Marcos12886 commited on
Commit
0433291
1 Parent(s): b9984fe

Delete mini_app.py

Browse files
Files changed (1) hide show
  1. mini_app.py +0 -47
mini_app.py DELETED
@@ -1,47 +0,0 @@
1
- import torch
2
- import gradio as gr
3
- from model import predict_params, AudioDataset
4
- from interfaz import estilo, my_theme
5
-
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
- model_class, id2label_class = predict_params(model_path="distilhubert-finetuned-mixed-data", dataset_path="data/mixed_data", filter_white_noise=True)
8
-
9
- def call(audiopath, model, dataset_path, filter_white_noise):
10
- model.to(device)
11
- model.eval()
12
- audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise,)
13
- processed_audio = audio_dataset.preprocess_audio(audiopath)
14
- inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
15
- with torch.no_grad():
16
- outputs = model(**inputs)
17
- logits = outputs.logits
18
- return logits
19
-
20
- def predict(audio_path_pred):
21
- with torch.no_grad():
22
- logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True)
23
- predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
24
- label_class = id2label_class[predicted_class_ids_class]
25
- label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
26
- label_class = label_mapping.get(predicted_class_ids_class, label_class)
27
- return label_class
28
-
29
- def cambiar_pestaña():
30
- return gr.update(visible=False), gr.update(visible=True)
31
-
32
- with gr.Blocks(theme=my_theme) as demo:
33
- estilo()
34
- with gr.Column(visible=False) as pag_predictor:
35
- gr.Markdown("<h2>Predictor</h2>")
36
- audio_input = gr.Audio(
37
- min_length=1.0,
38
- format="wav",
39
- label="Baby recorder",
40
- type="filepath",
41
- )
42
- gr.Button("¿Por qué llora?").click(
43
- predict,
44
- inputs=audio_input,
45
- outputs=gr.Textbox(label="Tu bebé llora por:")
46
- )
47
- demo.launch(share=True)