Update app.py
Browse files
app.py
CHANGED
@@ -16,18 +16,18 @@ from collections import Counter
|
|
16 |
device = torch.device("cpu")
|
17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
19 |
-
|
20 |
-
# model_path = '
|
21 |
-
model_path =
|
22 |
|
23 |
-
if os.path.exists(model_path):
|
24 |
-
|
25 |
-
|
26 |
|
27 |
|
28 |
-
title = "Upload an mp3 file for
|
29 |
description = """
|
30 |
-
The model was trained on Thai audio recordings with the following sentences
|
31 |
ชาวไร่ตัดต้นสนทำท่อนซุง\n
|
32 |
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
|
33 |
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
|
@@ -39,7 +39,13 @@ The model was trained on Thai audio recordings with the following sentences, so
|
|
39 |
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
|
40 |
"""
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
model.eval()
|
44 |
with torch.no_grad():
|
45 |
wav_data, _ = sf.read(file_path.name)
|
@@ -56,44 +62,15 @@ def actualpredict(file_path):
|
|
56 |
logits = model(**inputs).logits
|
57 |
logits = logits.squeeze()
|
58 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
59 |
-
return predicted_class_id
|
60 |
-
|
61 |
-
|
62 |
-
def predict(file_upload):
|
63 |
-
|
64 |
-
max_length = 100000
|
65 |
-
warn_output = " "
|
66 |
-
ans = " "
|
67 |
-
# file_path = file_upload
|
68 |
-
# if (microphone is not None) and (file_upload is not None):
|
69 |
-
# warn_output = (
|
70 |
-
# "WARNING: You've uploaded an audio file and used the microphone. "
|
71 |
-
# "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
|
72 |
-
# )
|
73 |
-
|
74 |
-
# elif (microphone is None) and (file_upload is None):
|
75 |
-
# return "ERROR: You have to either use the microphone or upload an audio file"
|
76 |
-
# if(microphone is not None):
|
77 |
-
# file_path = microphone
|
78 |
-
# if(file_upload is not None):
|
79 |
-
# file_path = file_upload
|
80 |
|
81 |
-
predicted_class_id = actualpredict(file_upload)
|
82 |
-
if(predicted_class_id==0):
|
83 |
-
ans = "no_parkinson"
|
84 |
-
else:
|
85 |
-
ans = "parkinson"
|
86 |
return predicted_class_id
|
87 |
gr.Interface(
|
88 |
fn=predict,
|
89 |
-
inputs=
|
90 |
-
gr.inputs.Audio(source="upload", type="filepath", optional=True),
|
91 |
-
],
|
92 |
outputs="text",
|
93 |
title=title,
|
94 |
description=description,
|
95 |
).launch()
|
96 |
|
97 |
-
# gr.inputs.Audio(source="microphone", type="filepath", optional=True),
|
98 |
# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
99 |
# iface.launch()
|
|
|
16 |
device = torch.device("cpu")
|
17 |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
18 |
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
|
19 |
+
model_path = "dysarthria_classifier12.pth"
|
20 |
+
# model_path = '/home/user/app/dysarthria_classifier12.pth'
|
21 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
22 |
|
23 |
+
# if os.path.exists(model_path):
|
24 |
+
# print(f"Loading saved model {model_path}")
|
25 |
+
# model.load_state_dict(torch.load(model_path))
|
26 |
|
27 |
|
28 |
+
title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
|
29 |
description = """
|
30 |
+
The model was trained on Thai audio recordings with the following sentences: \n
|
31 |
ชาวไร่ตัดต้นสนทำท่อนซุง\n
|
32 |
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
|
33 |
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
|
|
|
39 |
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
|
40 |
"""
|
41 |
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
def predict(file_path):
|
47 |
+
max_length = 100000
|
48 |
+
|
49 |
model.eval()
|
50 |
with torch.no_grad():
|
51 |
wav_data, _ = sf.read(file_path.name)
|
|
|
62 |
logits = model(**inputs).logits
|
63 |
logits = logits.squeeze()
|
64 |
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
|
|
|
|
|
|
|
|
|
|
66 |
return predicted_class_id
|
67 |
gr.Interface(
|
68 |
fn=predict,
|
69 |
+
inputs="file",
|
|
|
|
|
70 |
outputs="text",
|
71 |
title=title,
|
72 |
description=description,
|
73 |
).launch()
|
74 |
|
|
|
75 |
# iface = gr.Interface(fn=predict, inputs="file", outputs="text")
|
76 |
# iface.launch()
|