File size: 2,789 Bytes
73cab25 ba42b9f 73cab25 ba42b9f 1c411ce df4dfab e1cd816 ba42b9f 9beef86 223eb95 9beef86 aa1c032 9beef86 aa1c032 9beef86 33a5bcf aa1c032 9beef86 82be3cc aa1c032 82be3cc b1ac211 82be3cc 73cab25 50f2862 5914cfd 9beef86 5914cfd 5549008 5914cfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
import torch
import soundfile as sf
import os
import numpy as np
import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from collections import Counter
# Inference runs on the CPU; the checkpoint below is mapped onto this device.
device = torch.device("cpu")

# Pre-processor for the base wav2vec2 checkpoint; predict() uses it to turn
# raw audio samples into model input values.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Build the wav2vec2 architecture with a 2-way sequence-classification head.
# load_state_dict below is strict by default, so every parameter (including
# the freshly initialized head) is replaced by the saved checkpoint.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

# Resolve the checkpoint next to this script rather than relative to the
# current working directory, so start-up does not depend on where the app is
# launched from.
model_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "dysarthria_classifier12.pth"
)
model.load_state_dict(torch.load(model_path, map_location=device))
# Inference only: switch to eval mode once at start-up.
model.eval()
# Text shown in the Gradio UI.
title = "Upload an mp3 file for Parkinson's detection! (Thai Language)"

# The description lists the Thai prompts used for training. Each prompt ends
# with a literal "\n" plus a real line break, i.e. a blank line after every
# prompt, so each one is its own paragraph when rendered as Markdown.
# English glosses of the bracketed annotations:
#   (เน้นใช้ริมฝีปาก)  -> emphasises use of the lips
#   (เน้นใช้เพดานปาก)  -> emphasises use of the palate
#   (ดังขึ้นเรื่อยๆ)    -> getting gradually louder
description = """
The model was trained on Thai audio recordings with the following sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
"""
def _read_mono(file_obj):
    """Read an uploaded audio file and return its samples as a 1-D array.

    Accepts either a file wrapper exposing ``.name`` (what this app originally
    expected) or a plain path string (what some Gradio versions pass for
    ``"file"`` inputs). Multi-channel audio is down-mixed by averaging the
    channels, because ``sf.read`` returns a ``(frames, channels)`` array for
    such files, which the processor would otherwise treat as a batch.
    """
    path = getattr(file_obj, "name", file_obj)
    # NOTE(review): the file's native sample rate is discarded and the samples
    # are later labelled as 16 kHz. Resample non-16 kHz uploads only if the
    # training pipeline did so too — confirm before changing.
    wav_data, _ = sf.read(path)
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)
    return wav_data


def _pad_or_truncate(values, max_length):
    """Zero-pad on the right, or cut, *values* to exactly *max_length* along the last axis."""
    missing = max_length - values.shape[-1]
    if missing > 0:
        return F.pad(values, (0, missing))
    return values[..., :max_length]


def predict(file_path):
    """Classify one uploaded recording and return the predicted class index.

    Args:
        file_path: The Gradio upload, either a path string or a file wrapper
            with a ``.name`` attribute.

    Returns:
        The argmax class id from the 2-label classification head, or a prompt
        string when no file was uploaded.
    """
    if file_path is None:
        return "Please upload an audio file."

    # Fixed model input length in samples (6.25 s at the 16 kHz rate passed
    # to the processor below); shorter clips are zero-padded, longer ones cut.
    max_length = 100000
    model.eval()
    with torch.no_grad():
        wav_data = _read_mono(file_path)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        # Drop the leading batch axis the processor adds for a single clip,
        # fix the length, then restore the batch axis for the model.
        input_values = _pad_or_truncate(inputs.input_values.squeeze(0), max_length)
        logits = model(input_values=input_values.unsqueeze(0).to(device)).logits
    return torch.argmax(logits.squeeze(), dim=-1).item()
# Keep the UI in a module-level ``demo`` object and start the server only when
# this file is run as a script. Importing the module (for tests, or by
# Gradio's reload mode, which looks for ``demo``) then does not block on
# launch().
demo = gr.Interface(
    fn=predict,
    inputs="file",
    outputs="text",
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()