import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
device = torch.device("cpu")

# wav2vec2-base-960h is an English ASR checkpoint; attaching a 2-way
# sequence-classification head leaves that head randomly initialized,
# so the fine-tuned weights loaded below are required for meaningful output.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

# Fine-tuned classifier weights (class 1 = PP, class 0 = healthy). torch.load
# raises FileNotFoundError if the checkpoint is missing.
model_path = "dysarthria_classifier12.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
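# Optional sanity check (a sketch, not part of the original app): a dummy
# forward pass on one second of silence confirms the checkpoint matches the
# architecture and shows the expected input shape of (batch, samples).
# with torch.no_grad():
#     assert model(input_values=torch.zeros(1, 16000)).logits.shape == (1, 2)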
title = "Upload an mp3 file for Pseudobulbar Palsy (PP) detection! (Thai Language)"
description = """
The model was trained on Thai audio recordings of the sentences below, so please record one of these sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (focus on using the lips)\n
อีกาคอยคาบงูคาบไก่ (focus on using the palate)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (getting gradually louder)\n
“อาา อาาา อาาาาา”\n
"""
def predict(file_upload, microphone):
    max_length = 100000  # pad/truncate every clip to 100k samples (~6.25 s at 16 kHz)
    warn_output = ""
    if (microphone is not None) and (file_upload is not None):
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n\n"
        )
    elif (microphone is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Prefer the microphone recording when both inputs are present.
    file_path = microphone if microphone is not None else file_upload

    model.eval()
    with torch.no_grad():
        wav_data, _ = sf.read(file_path)
        if wav_data.ndim > 1:  # down-mix stereo recordings to mono
            wav_data = wav_data.mean(axis=1)
        # sf.read does not resample; the processor is simply told the audio is
        # 16 kHz, so files recorded at other rates will degrade predictions.
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)

        # Zero-pad short clips and truncate long ones to a fixed length.
        if max_length - input_values.shape[-1] > 0:
            input_values = torch.cat(
                [input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1
            )
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)

        logits = model(input_values=input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    if predicted_class_id == 1:
        return warn_output + "You probably have PP"
    return warn_output + "You probably don't have PP"
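# Quick local sanity check (hypothetical file name; assumes a 16 kHz mono WAV
# on disk next to this script):
#   print(predict("sample_16k.wav", None))
#   print(predict(None, None))  # -> "ERROR: You have to either use the microphone ..."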
# Note: gr.inputs.* with source=/optional= is the legacy Gradio API (pre-4.x);
# newer releases use gr.Audio(sources=["upload"]) / gr.Audio(sources=["microphone"]).
gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Audio(source="upload", type="filepath", optional=True),
        gr.inputs.Audio(source="microphone", type="filepath", optional=True),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()
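# Dependencies this script assumes (unpinned in the original; versions are a
# guess): an older gradio (pre-4.x, for gr.inputs.Audio), torch, transformers,
# and soundfile. torchaudio is only needed for the optional load_16k_mono helper.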