import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import WhisperModel, WhisperFeatureExtractor
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import librosa

# Select the device (CUDA if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the config for the model
config = {"encoder": "openai/whisper-base", "num_labels": 2}

# Define data class
class SpeechInferenceDataset(Dataset):
    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        inputs = self.text_processor(
            self.audio_data[index]["audio"]["array"],
            return_tensors="pt",
            sampling_rate=self.audio_data[index]["audio"]["sampling_rate"],
        )
        input_features = inputs.input_features
        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
        return input_features, decoder_input_ids
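
# Usage sketch (hypothetical names, for illustration only):
#   ds = SpeechInferenceDataset([{"audio": {"array": samples, "sampling_rate": 16000}}], fe)
#   feats, dec_ids = ds[0]  # feats: (1, 80, 3000) log-mel features for whisper-base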

# Define model class
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(SpeechClassifier, self).__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"]),
        )

    def forward(self, input_features, decoder_input_ids):
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        pooled_output = outputs["last_hidden_state"][:, 0, :]  # first decoder position
        logits = self.classifier(pooled_output)
        return logits
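
# Shape sketch, assuming the whisper-base checkpoint (hidden_size = 512):
# input_features (1, 80, 3000) -> WhisperModel -> last_hidden_state (1, 2, 512);
# the head reads decoder position 0, so logits come out as (1, num_labels).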

# Prepare data function
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    # Resample audio data to 16 kHz, the rate Whisper expects
    audio_data_resampled = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)

    # Initialize the feature extractor
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)

    # Wrap the single clip in the Dataset class
    dataset = SpeechInferenceDataset(
        [{"audio": {"array": audio_data_resampled, "sampling_rate": 16000}}],
        text_processor=feature_extractor,
    )
    return dataset
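
# Example call (hypothetical values): prepare_data(np.zeros(44100, dtype=np.float32), 44100)
# returns a one-item dataset; dataset[0] yields (input_features, decoder_input_ids).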

# Prediction function
def predict(audio_data, sampling_rate, config):
    dataset = prepare_data(audio_data, sampling_rate, config["encoder"])
    # Single-item dataset: index it rather than unpacking the Dataset itself,
    # which would raise a ValueError
    input_features, decoder_input_ids = dataset[0]

    model = SpeechClassifier(config).to(device)
    # Load the fine-tuned weights from the Hugging Face Hub
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin",
        map_location=device,
    ))
    model.eval()

    with torch.no_grad():
        logits = model(input_features.to(device), decoder_input_ids.to(device))
        predicted_class = int(torch.argmax(logits, dim=-1))
    return predicted_class
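
# Local smoke test (commented out; assumes a mono audio file "sample.wav" on disk):
# if __name__ == "__main__":
#     wav, sr = librosa.load("sample.wav", sr=None)
#     print(predict(wav, sr, config))  # 0 or 1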

# Gradio interface functions
def gradio_file_interface(uploaded_file):
    # uploaded_file is a filepath (str); decode it with librosa rather than
    # np.frombuffer, which would read the WAV header as samples and return
    # int16 data that librosa.resample rejects
    audio_data, sr = librosa.load(uploaded_file, sr=None)
    prediction = predict(audio_data, sr, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label

def gradio_mic_interface(mic_input):
    # mic_input is a tuple of (sample_rate, data), e.g.
    # (44100, array([0, 0, 0, ..., -153, -140, -120], dtype=int16))
    sample_rate, samples = mic_input
    audio_data = samples.astype(np.float32) / 32768.0  # int16 -> float32 in [-1, 1]
    prediction = predict(audio_data, sample_rate, config)
    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
    return label
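
# Note: with type="numpy" Gradio passes (sample_rate, int16_array), while
# type="filepath" passes a path string; both handlers above normalize the
# audio to float32 before calling predict().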

# Define the interfaces inside the Blocks context
with gr.Blocks() as demo:
    # File Upload Tab
    with gr.Tab("Upload File"):
        gr.Interface(
            fn=gradio_file_interface,
            inputs=gr.Audio(sources="upload", type="filepath"),  # filepath for uploaded audio files
            outputs=gr.Textbox(label="Prediction"),
        )

    # Mic Tab
    with gr.Tab("Record Using Microphone"):
        gr.Interface(
            fn=gradio_mic_interface,
            inputs=gr.Audio(sources="microphone", type="numpy"),  # numpy for real-time audio like microphone
            outputs=gr.Textbox(label="Prediction"),
        )

# Launch the demo
demo.launch()