import os

import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Binary classifier (dysarthric vs. non-dysarthric speech) built on a wav2vec2 backbone.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

# Load fine-tuned weights if they exist alongside the script.
model_path = "dysarthria_classifier12.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")


def predict(file_path):
    """Classify an uploaded audio file and return the predicted class id (0 or 1)."""
    max_length = 100000  # fixed input length in samples (~6.25 s at 16 kHz)
    model.eval()
    with torch.no_grad():
        # Gradio's file input may pass a tempfile-like object or a plain path string.
        path = file_path.name if hasattr(file_path, "name") else file_path
        wav_data, _ = sf.read(path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)

        # Zero-pad short clips and truncate long ones to exactly max_length samples.
        if input_values.shape[-1] < max_length:
            pad = torch.zeros(max_length - input_values.shape[-1])
            input_values = torch.cat([input_values, pad], dim=-1)
        else:
            input_values = input_values[:max_length]

        input_values = input_values.unsqueeze(0).to(device)
        logits = model(input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    return predicted_class_id


iface = gr.Interface(fn=predict, inputs="file", outputs="text")
iface.launch()
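
# A note on sampling rate: predict() tells the processor the audio is 16 kHz,
# but sf.read() returns whatever rate the file was recorded at, so clips at
# other rates are silently misinterpreted. A minimal sketch of a resampling
# helper, assuming torchaudio is available (an added dependency, not used above;
# load_resampled is a hypothetical name, not part of the original script):
#
#   import torchaudio
#
#   def load_resampled(path, target_sr=16000):
#       wav, sr = torchaudio.load(path)  # wav: (channels, samples) float tensor
#       if sr != target_sr:
#           wav = torchaudio.functional.resample(wav, sr, target_sr)
#       return wav.mean(dim=0).numpy()   # downmix to mono for the processor
#
# predict() could then call load_resampled(path) in place of sf.read(path).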