|
import functools

import datasets
import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
|
|
|
|
|
# Use the CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model settings: "encoder" is the Whisper checkpoint loaded as the backbone,
# "num_labels" is the width of the classifier's output layer (label 1 is
# reported as hypernasality by the Gradio handlers).
config = dict(encoder="openai/whisper-base", num_labels=2)
|
|
|
|
|
class SpeechInferenceDataset(Dataset):
    """Map-style dataset over in-memory audio records for inference.

    Each record is expected to look like
    ``{"audio": {"array": <samples>, "sampling_rate": <Hz>}}``.
    Despite its name, ``text_processor`` is used as an audio feature
    extractor: it is called on the raw samples and must return an object
    exposing ``input_features``.
    """

    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        """Return ``(input_features, decoder_input_ids)`` for record *index*."""
        clip = self.audio_data[index]["audio"]
        processed = self.text_processor(
            clip["array"],
            return_tensors="pt",
            sampling_rate=clip["sampling_rate"],
        )
        # Drop the leading (presumably batch) axis so the DataLoader can re-batch.
        features = processed.input_features.squeeze(0)
        # Fixed dummy decoder prompt; SpeechClassifier pools position 0 of the output.
        prompt_ids = torch.tensor([[1, 1]])
        return features, prompt_ids
|
|
|
|
|
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper backbone followed by a ReLU MLP classification head.

    ``config`` must provide ``"encoder"`` (checkpoint id passed to
    ``WhisperModel.from_pretrained``) and ``"num_labels"`` (size of the
    final output layer).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])

        widths = [self.encoder.config.hidden_size, 4096, 2048, 1024, 512]
        head = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.ReLU())
        # The final projection to logits has no activation.
        head.append(nn.Linear(widths[-1], config["num_labels"]))
        # Keep this module order: saved weights are keyed by Sequential index
        # (classifier.0 ... classifier.8).
        self.classifier = nn.Sequential(*head)

    def forward(self, input_features, decoder_input_ids):
        """Return logits of shape (batch, num_labels).

        Pools by taking position 0 of ``last_hidden_state`` (the decoder
        output of a Whisper seq2seq model).
        """
        hidden = self.encoder(
            input_features, decoder_input_ids=decoder_input_ids
        ).last_hidden_state
        return self.classifier(hidden[:, 0])
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_feature_extractor(model_checkpoint):
    """Load the Whisper feature extractor for *model_checkpoint*, once per checkpoint."""
    return WhisperFeatureExtractor.from_pretrained(model_checkpoint)


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Wrap a raw waveform in a single-item DataLoader of Whisper inputs.

    Args:
        audio_data: waveform samples (anything ``np.asarray`` accepts);
            expected to be mono / 1-D — TODO confirm for all callers.
        sampling_rate: sample rate of *audio_data* in Hz.
        model_checkpoint: checkpoint id whose feature extractor is used.

    Returns:
        A ``DataLoader`` (batch_size=1) yielding one
        ``(input_features, decoder_input_ids)`` batch.

    NOTE(review): the Whisper feature extractor pads/truncates to a fixed
    window by default, so long clips are presumably cut — confirm.
    """
    waveform = np.asarray(audio_data, dtype=np.float32)
    # Resample to 16 kHz, the rate declared to the feature extractor below.
    waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)

    # from_pretrained hits disk/network; cache it instead of reloading per request.
    feature_extractor = _load_feature_extractor(model_checkpoint)

    dataset = SpeechInferenceDataset(
        [{"audio": {"array": waveform, "sampling_rate": 16000}}],
        text_processor=feature_extractor,
    )
    return DataLoader(dataset, batch_size=1)
|
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_model(config_items):
    """Build a SpeechClassifier from frozen config items and load the fine-tuned weights.

    Cached so the Whisper backbone and checkpoint are loaded once per
    distinct config rather than on every request.
    """
    model = SpeechClassifier(dict(config_items)).to(device)
    state_dict = torch.hub.load_state_dict_from_url(
        "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(audio_data, sampling_rate, config):
    """Classify a waveform and return the predicted class index.

    Args:
        audio_data: waveform samples.
        sampling_rate: sample rate of *audio_data* in Hz.
        config: model config dict with ``"encoder"`` and ``"num_labels"``;
            its values are assumed hashable (used as a cache key).

    Returns:
        ``int`` argmax of the classifier logits (the Gradio handlers treat
        1 as "Hypernasality Detected").

    Raises:
        ValueError: if the prepared dataloader yields no batches.
    """
    dataloader = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _load_model(tuple(sorted(config.items())))

    predicted_id = None
    with torch.no_grad():
        for input_features, decoder_input_ids in dataloader:
            logits = model(input_features.to(device), decoder_input_ids.to(device))
            # batch_size is 1, so argmax yields a single-element tensor.
            predicted_id = int(torch.argmax(logits, dim=-1))

    if predicted_id is None:
        raise ValueError("prepare_data produced no batches to classify")
    return predicted_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_file_interface(uploaded_file):
    """Classify an uploaded audio file and return a human-readable label.

    Args:
        uploaded_file: path to the uploaded audio (the upload Audio component
            uses ``type="filepath"``), or ``None`` when nothing was uploaded.

    Raises:
        gr.Error: shown in the UI when no file was provided.
    """
    if uploaded_file is None:
        raise gr.Error("Please upload an audio file first.")

    # sr=None keeps the native rate; prepare_data resamples to 16 kHz.
    audio_data, sampling_rate = librosa.load(uploaded_file, sr=None)
    prediction = predict(audio_data, sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mic_audio_to_float32(samples):
    """Convert microphone samples to a mono float32 waveform.

    Signed-integer PCM (presumably int16 from Gradio's numpy microphone
    output — verify for your Gradio version) is scaled by 2**(bits-1) into
    [-1, 1), matching the float waveform ``librosa.load`` yields for uploaded
    files. Whisper log-mel features depend on absolute amplitude, so skipping
    this would shift the model inputs. 2-D input laid out as
    (samples, channels) is averaged to mono; resampling a 2-D array would
    otherwise act on the channel axis.
    """
    audio = np.asarray(samples)
    if np.issubdtype(audio.dtype, np.signedinteger):
        scale = float(-np.iinfo(audio.dtype).min)
        audio = audio.astype(np.float32) / scale
    else:
        audio = audio.astype(np.float32)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    return audio


def gradio_mic_interface(mic_input):
    """Classify a microphone recording and return a human-readable label.

    Args:
        mic_input: ``(sampling_rate, samples)`` from the microphone Audio
            component (``type="numpy"``), or ``None`` if nothing was recorded.

    Raises:
        gr.Error: shown in the UI when no recording was provided.
    """
    if mic_input is None:
        raise gr.Error("Please record some audio first.")

    sampling_rate, samples = mic_input
    prediction = predict(_mic_audio_to_float32(samples), sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
|
|
|
|
|
# Two tabs, one per input source; both show the prediction as text.
with gr.Blocks() as demo:
    with gr.Tab("Upload File"):
        gr.Interface(
            fn=gradio_file_interface,
            inputs=gr.Audio(sources="upload", type="filepath"),
            outputs=gr.Textbox(label="Prediction"),
        )

    with gr.Tab("Record Using Microphone"):
        gr.Interface(
            fn=gradio_mic_interface,
            inputs=gr.Audio(sources="microphone", type="numpy"),
            outputs=gr.Textbox(label="Prediction"),
        )

# Launch only when executed as a script, so importing this module (tests,
# `gradio` CLI reload mode) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)