Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,13 +8,14 @@ import numpy as np
|
|
8 |
|
9 |
|
10 |
|
11 |
-
|
|
|
12 |
|
13 |
|
14 |
|
15 |
def classify_audio(audio_file):
|
16 |
model = AutoModelForAudioClassification.from_pretrained("3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes", trust_remote_code=True)
|
17 |
-
|
18 |
|
19 |
sr, raw_wav = audio_file
|
20 |
|
@@ -30,8 +31,14 @@ def classify_audio(audio_file):
|
|
30 |
wavs = torch.tensor(norm_wav).unsqueeze(0)
|
31 |
|
32 |
pred = model(wavs, mask).detach().numpy()
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
def main():
|
|
|
8 |
|
9 |
|
10 |
|
11 |
+
mean, std = -8.278621631819787e-05, 0.08485510250851999
|
12 |
+
id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
|
13 |
|
14 |
|
15 |
|
16 |
def classify_audio(audio_file):
|
17 |
model = AutoModelForAudioClassification.from_pretrained("3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes", trust_remote_code=True)
|
18 |
+
|
19 |
|
20 |
sr, raw_wav = audio_file
|
21 |
|
|
|
31 |
wavs = torch.tensor(norm_wav).unsqueeze(0)
|
32 |
|
33 |
pred = model(wavs, mask).detach().numpy()
|
34 |
+
|
35 |
+
pred = {}
|
36 |
+
for i, audio_pred in enumerate(pred.numpy()):
|
37 |
+
pred[i] = {}
|
38 |
+
for att_i, att_val in enumerate(audio_pred):
|
39 |
+
pred[i][id2label[att_i]] = att_val
|
40 |
+
|
41 |
+
return pred
|
42 |
|
43 |
|
44 |
def main():
|