# wav2vec / app.py
# Gradio demo: binary dysarthria classification with a fine-tuned Wav2Vec2 model.
import os

import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# Device must be defined before the model is moved onto it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

# Load the fine-tuned classifier weights if they ship with the app
model_path = "dysarthria_classifier12.pth"
if os.path.exists(model_path):
    print(f"Loading saved model {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
def predict(file_path):
    """Return the predicted class id (0 or 1) for an uploaded audio file."""
    max_length = 100000  # fixed clip length in samples (~6.25 s at 16 kHz)
    model.eval()
    with torch.no_grad():
        # Gradio's "file" input passes a tempfile-like object; .name is its path
        wav_data, _ = sf.read(file_path.name)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)

        # Zero-pad short clips and truncate long ones so every input is max_length samples
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat(
                [input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1
            )
        else:
            input_values = input_values[:max_length]

        input_values = input_values.unsqueeze(0).to(device)
        logits = model(input_values=input_values).logits
        logits = logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()
    return predicted_class_id
# File-upload UI; the predicted class id is rendered as text
iface = gr.Interface(fn=predict, inputs="file", outputs="text")
iface.launch()
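
# Local smoke test (a sketch, not part of the deployed app): predict() expects an
# object with a .name attribute, matching what Gradio's legacy "file" input passes.
# "sample.wav" below is a hypothetical 16 kHz mono recording.
#
#     from types import SimpleNamespace
#     print(predict(SimpleNamespace(name="sample.wav")))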