|
import gradio as gr |
|
import torch |
|
import soundfile as sf |
|
import os |
|
import numpy as np |
|
|
|
import os |
|
import soundfile as sf |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification |
|
from collections import Counter |
|
|
|
# All inference runs on CPU; the fine-tuned checkpoint is remapped onto this
# device when it is loaded below.
device = torch.device("cpu")
model_path = "dysarthria_classifier12.pth"

# Feature extractor + tokenizer paired with the wav2vec2-base-960h backbone;
# predict() uses it to normalize raw waveforms into ``input_values``.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Build a 2-label sequence-classification model on the same backbone, then
# replace its weights with the fine-tuned state dict at ``model_path``
# (load_state_dict is strict by default, so the checkpoint keys must match).
# NOTE(review): torch.load unpickles the file -- only ship trusted checkpoints.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h",
    num_labels=2,
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
|
|
|
|
|
|
|
|
|
|
# Heading and Markdown description rendered on the Gradio page.
title = "Upload an mp3 file for Parkinson's detection! (Thai Language)"

# Prompts used in the training recordings. The parenthetical Thai hints read
# roughly "emphasizes the lips", "emphasizes the palate" and "gradually louder".
# NOTE(review): the <img> below points at an asset from the unrelated
# "course-demos/Rick_and_Morty_QA" Space -- presumably a template leftover;
# confirm and replace it with a relevant image or drop it.
description = """
The model was trained on Thai audio recordings with the following sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
"""
|
|
|
|
|
|
|
|
|
|
|
def _read_mono_audio(file_path, target_sr):
    """Read an uploaded audio file as a 1-D (mono) float array at ``target_sr`` Hz."""
    # Older Gradio versions pass a tempfile wrapper exposing ``.name``; newer
    # ones pass the path string itself. Accept either.
    path = getattr(file_path, "name", file_path)
    wav_data, sample_rate = sf.read(path)

    # sf.read returns a (frames, channels) array for multi-channel files. The
    # processor would treat each frame as a separate batch item, so downmix.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)

    # The processor is told the audio is ``target_sr``; make that true instead
    # of silently feeding e.g. 44.1 kHz samples as 16 kHz. This is plain
    # linear interpolation (no anti-aliasing filter) -- adequate for a demo,
    # but a proper resampler would be more faithful when downsampling.
    if sample_rate != target_sr and wav_data.size > 0:
        n_out = int(round(wav_data.shape[0] * target_sr / sample_rate))
        src_times = np.arange(wav_data.shape[0]) / sample_rate
        dst_times = np.arange(n_out) / target_sr
        wav_data = np.interp(dst_times, src_times, wav_data)

    return wav_data


def _fit_length(values, length):
    """Zero-pad (at the end) or truncate a 1-D tensor to exactly ``length`` samples."""
    missing = length - values.shape[-1]
    if missing > 0:
        # new_zeros keeps the dtype/device of ``values``.
        return torch.cat([values, values.new_zeros(missing)], dim=-1)
    return values[:length]


def predict(file_path):
    """Classify one uploaded recording and return the predicted class index.

    Args:
        file_path: Value Gradio passes for a ``"file"`` input -- either an
            object with a ``.name`` attribute or a filesystem path string.

    Returns:
        int: index of the highest-scoring logit of the 2-label classifier.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if file_path is None:
        raise gr.Error("Please upload an audio file.")

    max_length = 100000  # fixed model input length in samples (6.25 s at 16 kHz)
    sampling_rate = 16000

    model.eval()
    with torch.no_grad():
        wav_data = _read_mono_audio(file_path, sampling_rate)
        inputs = processor(
            wav_data, sampling_rate=sampling_rate, return_tensors="pt", padding=True
        )
        # (1, T) -> (T,), pad/truncate to the fixed length, then restore the batch dim.
        input_values = _fit_length(inputs.input_values.squeeze(0), max_length)
        input_values = input_values.unsqueeze(0).to(device)

        logits = model(input_values=input_values).logits
        predicted_class_id = torch.argmax(logits.squeeze(), dim=-1).item()

    return predicted_class_id
|
# Build the UI at import time but only start the server when run as a script,
# so importing this module (tests, tooling, or Gradio's auto-reload mode,
# which by default looks for a module-level ``demo``) has no side effects.
demo = gr.Interface(
    fn=predict,
    inputs="file",
    outputs="text",
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|