# Last commit: "Update app.py" by lilyhof (5ed82c5, verified).
import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import librosa
# Run on the first CUDA GPU torch can see; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model settings read by SpeechClassifier and predict:
#   "encoder"    - Hugging Face checkpoint id of the Whisper backbone
#   "num_labels" - number of classes produced by the classification head
config = dict(encoder="openai/whisper-base", num_labels=2)
# Map-style dataset feeding raw waveforms through a feature processor.
class SpeechInferenceDataset(Dataset):
    """Wrap raw audio clips as a dataset of ``(features, decoder ids)`` pairs.

    Args:
        audio_data: Indexable sequence of records shaped like
            ``{"audio": {"array": <waveform>, "sampling_rate": <int>}}``.
        text_processor: Callable invoked as
            ``text_processor(waveform, return_tensors="pt", sampling_rate=rate)``
            that returns an object exposing ``input_features`` (in this app it
            is a ``WhisperFeatureExtractor``).
    """

    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        """Return ``(input_features, decoder_input_ids)`` for clip ``index``.

        ``squeeze(0)`` drops the size-1 leading dimension the processor adds
        for a single waveform, leaving ``DataLoader`` to add the batch axis.
        """
        clip = self.audio_data[index]["audio"]
        processed = self.text_processor(
            clip["array"],
            return_tensors="pt",
            sampling_rate=clip["sampling_rate"],
        )
        features = processed.input_features.squeeze(0)
        # Fixed two-token decoder prompt. The ids (1, 1) are presumably what the
        # classifier was trained with -- TODO confirm before changing.
        prompt_ids = torch.tensor([[1, 1]])
        return features, prompt_ids
# Whisper backbone + MLP head producing class logits.
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Speech classifier: pretrained ``WhisperModel`` followed by an MLP.

    Args:
        config: Dict providing ``"encoder"`` (checkpoint id passed to
            ``WhisperModel.from_pretrained``) and ``"num_labels"`` (output
            width of the final layer).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])

        # Linear -> ReLU for each hidden width, then a final Linear onto the
        # label space. Module indices must stay 0..8 (Linears at 0, 2, 4, 6, 8)
        # so state_dict keys such as "classifier.0.weight" match saved weights.
        widths = [self.encoder.config.hidden_size, 4096, 2048, 1024, 512]
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], config["num_labels"]))
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_features, decoder_input_ids):
        """Return class logits for a batch of Whisper input features.

        The head reads position 0 of ``last_hidden_state`` from ``WhisperModel``
        (per the transformers docs, the decoder's final hidden states). The
        features are assumed to be (batch, n_mels, frames) -- TODO confirm.
        """
        hidden = self.encoder(
            input_features, decoder_input_ids=decoder_input_ids
        ).last_hidden_state
        return self.classifier(hidden[:, 0, :])
# Prepare data function
# Feature extractors keyed by checkpoint id, so repeated predictions do not
# call WhisperFeatureExtractor.from_pretrained on every request.
_FEATURE_EXTRACTOR_CACHE = {}


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Turn one mono waveform into a single-item DataLoader of Whisper inputs.

    Args:
        audio_data: 1-D waveform (any array-like of numbers) sampled at
            ``sampling_rate``. It should be in the amplitude range the model was
            trained on (presumably [-1, 1], as ``librosa.load`` yields) --
            confirm.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        model_checkpoint: Checkpoint id used to load the feature extractor.

    Returns:
        A ``DataLoader`` with ``batch_size=1`` over a single
        ``(input_features, decoder_input_ids)`` item.
    """
    feature_extractor = _FEATURE_EXTRACTOR_CACHE.get(model_checkpoint)
    if feature_extractor is None:
        feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
        _FEATURE_EXTRACTOR_CACHE[model_checkpoint] = feature_extractor

    # Resample to the rate the extractor was configured for (16 kHz for
    # Whisper checkpoints), so the rate we pass below always matches it.
    target_sr = feature_extractor.sampling_rate
    waveform = np.asarray(audio_data, dtype=np.float32)
    waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=target_sr)

    dataset = SpeechInferenceDataset(
        [{"audio": {"array": waveform, "sampling_rate": target_sr}}],
        text_processor=feature_extractor,
    )
    return DataLoader(dataset, batch_size=1)
# Prediction function
# Fine-tuned classifier weights on the Hugging Face Hub.
_WEIGHTS_URL = "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin"
# Ready-to-use models keyed by (encoder checkpoint, num_labels), so a request
# does not rebuild Whisper and reload the weights every time.
_MODEL_CACHE = {}


def _get_model(config):
    """Return an eval-mode SpeechClassifier for ``config`` on ``device``, built once."""
    # SpeechClassifier only reads these two keys, so they fully identify the model.
    key = (config["encoder"], config["num_labels"])
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = SpeechClassifier(config).to(device)
        state_dict = torch.hub.load_state_dict_from_url(_WEIGHTS_URL, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def predict(audio_data, sampling_rate, config):
    """Classify one waveform and return the predicted class index.

    Args:
        audio_data: Mono waveform, as accepted by ``prepare_data``.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        config: Dict with ``"encoder"`` and ``"num_labels"`` keys.

    Returns:
        ``int`` index of the highest-scoring class (the UI maps 1 to
        "Hypernasality Detected"), or ``None`` if the loader yields no batch.
    """
    dataloader = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _get_model(config)
    with torch.no_grad():
        for input_features, decoder_input_ids in dataloader:
            logits = model(input_features.to(device), decoder_input_ids.to(device))
            # prepare_data builds a single batch of size 1; classify it and stop.
            return int(torch.argmax(logits, dim=-1))
    return None
# Gradio Interface functions
def gradio_file_interface(uploaded_file):
    """Classify an uploaded audio file and return a human-readable label.

    Args:
        uploaded_file: Path to the uploaded file (the upload tab's
            ``gr.Audio`` uses ``type="filepath"``), or ``None`` when the user
            submits without choosing a file.

    Returns:
        "Hypernasality Detected", "No Hypernasality Detected", or
        "No audio provided." when there is nothing to classify.
    """
    if uploaded_file is None:
        return "No audio provided."
    # sr=None keeps the file's native rate (prepare_data resamples); librosa.load
    # returns mono float samples by default.
    audio_data, sampling_rate = librosa.load(uploaded_file, sr=None)
    prediction = predict(audio_data, sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
def _to_float_mono(samples):
    """Return ``samples`` as a 1-D float32 waveform with integers scaled to [-1, 1].

    Integer PCM (the microphone delivers int16) is divided by its full-scale
    value so it matches the amplitude range ``librosa.load`` gives the upload
    path. Multi-channel input laid out as (samples, channels), as Gradio's
    numpy audio uses, is averaged down to one channel. Float input is only
    cast to float32.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        # midpoint is 0 for signed types and 2**(bits-1) for unsigned ones;
        # half_range is 2**(bits-1), e.g. 32768 for int16.
        midpoint = (int(info.max) + int(info.min) + 1) / 2
        half_range = (int(info.max) - int(info.min) + 1) / 2
        samples = (samples.astype(np.float64) - midpoint) / half_range
    if samples.ndim > 1:
        # Without this, librosa.resample would treat the channel axis as time.
        samples = samples.mean(axis=1)
    return samples.astype(np.float32)


def gradio_mic_interface(mic_input):
    """Classify a microphone recording and return a human-readable label.

    Args:
        mic_input: ``(sampling_rate, samples)`` tuple from ``gr.Audio`` with
            ``type="numpy"``, e.g.
            ``(44100, array([0, 0, 0, ..., -153, -140, -120], dtype=int16))``,
            or ``None`` when nothing was recorded.

    Returns:
        "Hypernasality Detected", "No Hypernasality Detected", or
        "No audio provided." when there is nothing to classify.
    """
    if mic_input is None:
        return "No audio provided."
    sampling_rate, samples = mic_input[0], mic_input[1]
    prediction = predict(_to_float_mono(samples), sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
# Two-tab UI, each tab hosting its own gr.Interface. The Audio input's `type`
# decides what the handler receives:
#   "filepath" -> a path on disk (read with librosa.load)
#   "numpy"    -> a (sampling_rate, samples) tuple from the microphone
with gr.Blocks() as demo:
    with gr.Tab("Upload File"):
        upload_audio = gr.Audio(sources="upload", type="filepath")
        upload_result = gr.Textbox(label="Prediction")
        gr.Interface(fn=gradio_file_interface, inputs=upload_audio, outputs=upload_result)

    with gr.Tab("Record Using Microphone"):
        mic_audio = gr.Audio(sources="microphone", type="numpy")
        mic_result = gr.Textbox(label="Prediction")
        gr.Interface(fn=gradio_mic_interface, inputs=mic_audio, outputs=mic_result)

# debug=True blocks the main thread and surfaces errors in the console output.
demo.launch(debug=True)