Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,8 @@ import numpy as np
|
|
9 |
|
10 |
|
11 |
mean, std = -8.278621631819787e-05, 0.08485510250851999
|
|
|
|
|
12 |
id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
|
13 |
description_text = "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
|
14 |
The model is trained on MSP-Podcast. \
|
@@ -28,19 +30,22 @@ def classify_audio(audio_file):
|
|
28 |
|
29 |
y = raw_wav.astype(np.float32, order='C') / np.iinfo(raw_wav.dtype).max
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
norm_wav = (y - mean) / (std+0.000001)
|
33 |
|
|
|
|
|
34 |
mask = torch.ones(1, len(norm_wav))
|
35 |
wavs = torch.tensor(norm_wav).unsqueeze(0)
|
36 |
|
37 |
pred = model(wavs, mask).detach().numpy()
|
38 |
|
39 |
-
|
40 |
-
if sr != 16000:
|
41 |
-
output += "{} sampling rate is uncompatible. The model was trained on {} sampleing rate\n".format(sr, 16000)
|
42 |
-
# for i, audio_pred in enumerate(pred):
|
43 |
-
# output[i] = {}
|
44 |
for att_i, att_val in enumerate(pred[0]):
|
45 |
output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
|
46 |
|
|
|
9 |
|
10 |
|
11 |
mean, std = -8.278621631819787e-05, 0.08485510250851999
|
12 |
+
model_sr=model.config.sampling_rate
|
13 |
+
|
14 |
id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
|
15 |
description_text = "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
|
16 |
The model is trained on MSP-Podcast. \
|
|
|
30 |
|
31 |
y = raw_wav.astype(np.float32, order='C') / np.iinfo(raw_wav.dtype).max
|
32 |
|
33 |
+
|
34 |
+
|
35 |
+
output = ''
|
36 |
+
if sr != 16000:
|
37 |
+
y = librosa.resample(y, orig_sr=sr, target_sr=model_sr)
|
38 |
+
output += "{} sampling rate is uncompatible, converted to {} as the model was trained on {} sampling rate<br>".format(sr, model_sr, model_sr)
|
39 |
|
|
|
40 |
|
41 |
+
|
42 |
+
norm_wav = (y - mean) / (std+0.000001)
|
43 |
mask = torch.ones(1, len(norm_wav))
|
44 |
wavs = torch.tensor(norm_wav).unsqueeze(0)
|
45 |
|
46 |
pred = model(wavs, mask).detach().numpy()
|
47 |
|
48 |
+
|
|
|
|
|
|
|
|
|
49 |
for att_i, att_val in enumerate(pred[0]):
|
50 |
output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
|
51 |
|