3loi commited on
Commit
c3d029a
1 Parent(s): f9b0e0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -8,13 +8,14 @@ import numpy as np
8
 
9
 
10
 
11
-
 
12
 
13
 
14
 
15
  def classify_audio(audio_file):
16
  model = AutoModelForAudioClassification.from_pretrained("3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes", trust_remote_code=True)
17
- mean, std = -8.278621631819787e-05, 0.08485510250851999
18
 
19
  sr, raw_wav = audio_file
20
 
@@ -30,8 +31,14 @@ def classify_audio(audio_file):
30
  wavs = torch.tensor(norm_wav).unsqueeze(0)
31
 
32
  pred = model(wavs, mask).detach().numpy()
33
- print(str(pred))
34
- return str(pred)
 
 
 
 
 
 
35
 
36
 
37
  def main():
 
8
 
9
 
10
 
11
+ mean, std = -8.278621631819787e-05, 0.08485510250851999
12
+ id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
13
 
14
 
15
 
16
  def classify_audio(audio_file):
17
  model = AutoModelForAudioClassification.from_pretrained("3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes", trust_remote_code=True)
18
+
19
 
20
  sr, raw_wav = audio_file
21
 
 
31
  wavs = torch.tensor(norm_wav).unsqueeze(0)
32
 
33
  pred = model(wavs, mask).detach().numpy()
34
+
35
+ pred = {}
36
+ for i, audio_pred in enumerate(pred.numpy()):
37
+ pred[i] = {}
38
+ for att_i, att_val in enumerate(audio_pred):
39
+ pred[i][id2label[att_i]] = att_val
40
+
41
+ return pred
42
 
43
 
44
  def main():