import torch
import gradio as gr

from model import predict_params, AudioDataset
from interfaz import estilo, my_theme

# Run inference on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned DistilHuBERT classifier and its id -> label mapping.
model_class, id2label_class = predict_params(
    model_path="distilhubert-finetuned-mixed-data",
    dataset_path="data/mixed_data",
    filter_white_noise=True,
)


def call(audiopath, model, dataset_path, filter_white_noise):
    """Preprocess one audio file and return the model's raw logits."""
    model.to(device)
    model.eval()
    # The empty dict fills a constructor argument that is apparently unused
    # when the dataset is only needed for inference-time preprocessing.
    audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise)
    processed_audio = audio_dataset.preprocess_audio(audiopath)
    # Add a batch dimension before the forward pass.
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits


def predict(audio_path_pred):
    """Classify a recording and return a human-readable label."""
    logits = call(
        audio_path_pred,
        model=model_class,
        dataset_path="data/mixed_data",
        filter_white_noise=True,
    )
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    label_class = id2label_class[predicted_class_id]
    # Map class ids to the Spanish labels shown in the UI: hunger,
    # breathing problems, pain, tiredness/discomfort.
    label_mapping = {
        0: "Hambre",
        1: "Problemas para respirar",
        2: "Dolor",
        3: "Cansancio/Incomodidad",
    }
    return label_mapping.get(predicted_class_id, label_class)


def cambiar_pestaña():
    # Hide one component and reveal another; intended for switching
    # between pages of the interface.
    return gr.update(visible=False), gr.update(visible=True)


with gr.Blocks(theme=my_theme) as demo:
    estilo()
    # Shown directly here. The full app starts this column hidden
    # (visible=False) and uses cambiar_pestaña to reveal it from a cover
    # page; without that wiring the demo would render an empty page.
    with gr.Column(visible=True) as pag_predictor:
        gr.Markdown("# Predictor")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
        )
        gr.Button("¿Por qué llora?").click(
            predict,
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:"),
        )

demo.launch(share=True)