dkounadis committed on
Commit
6a3daac
1 Parent(s): e2d1e1b

do not apply clip()

Browse files
Files changed (1) hide show
  1. README.md +18 -18
README.md CHANGED
@@ -41,22 +41,23 @@ Florian Eyben, Felix Burkhardt, Björn Schuller.
41
 
42
 
43
 
44
- # Usage
45
  ```python
46
- from transformers import AutoModelForAudioClassification
47
- from transformers.models.wav2vec2.modeling_wav2vec2 import (
48
- Wav2Vec2Model,
49
- Wav2Vec2PreTrainedModel
50
- )
51
  import torch
52
  import types
53
  import torch.nn as nn
 
 
 
 
54
 
55
- signal = torch.rand((1, 16000)) # audio signal 16 KHz
 
 
56
  device = 'cpu'
57
 
58
- class RegressionHead(nn.Module):
59
- r"""A/D/V"""
60
 
61
  def __init__(self, config):
62
 
@@ -81,14 +82,15 @@ class Dawn(Wav2Vec2PreTrainedModel):
81
  super().__init__(config)
82
 
83
  self.wav2vec2 = Wav2Vec2Model(config)
84
- self.classifier = RegressionHead(config)
85
 
86
  def forward(self, x):
87
  '''x: (batch, audio-samples-16KHz)'''
88
- x = x - x.mean(1, keepdim=True)
89
  variance = (x * x).mean(1, keepdim=True) + 1e-7
90
- out = self.wav2vec2(x / variance.sqrt())
91
- return self.classifier(out[0].mean(1)).clip(0, 1)
 
92
 
93
 
94
  def _infer(self, x):
@@ -125,9 +127,7 @@ dawn = Dawn.from_pretrained(
125
  def wav2small(x):
126
  return .5 * dawn(x) + .5 * base(x)
127
 
128
-
129
- with torch.no_grad():
130
- pred = wav2small(signal.to(device))
131
- print(f'\nArousal = {pred[0, 0]} Dominance = {pred[0, 1]}',
132
- f' Valence = {pred[0, 2]}')
133
  ```
 
41
 
42
 
43
 
44
+ # How To
45
  ```python
46
+ import librosa
 
 
 
 
47
  import torch
48
  import types
49
  import torch.nn as nn
50
+ from transformers import AutoModelForAudioClassification
51
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
52
+ Wav2Vec2Model,
53
+ Wav2Vec2PreTrainedModel)
54
 
55
+
56
+ signal = torch.from_numpy(
57
+ librosa.load('test.wav', sr=16000)[0])[None, :]
58
  device = 'cpu'
59
 
60
+ class ADV(nn.Module):
 
61
 
62
  def __init__(self, config):
63
 
 
82
  super().__init__(config)
83
 
84
  self.wav2vec2 = Wav2Vec2Model(config)
85
+ self.classifier = ADV(config)
86
 
87
  def forward(self, x):
88
  '''x: (batch, audio-samples-16KHz)'''
89
+ x -= x.mean(1, keepdim=True)
90
  variance = (x * x).mean(1, keepdim=True) + 1e-7
91
+ x = self.wav2vec2(x / variance.sqrt()
92
+ ).last_hidden_state
93
+ return self.classifier(x.mean(1))
94
 
95
 
96
  def _infer(self, x):
 
127
  def wav2small(x):
128
  return .5 * dawn(x) + .5 * base(x)
129
 
130
+ pred = wav2small(signal.to(device))
131
+ print(f'\nArousal = {pred[:, 0]} Dominance = {pred[:, 1]}',
132
+ f' Valence = {pred[:, 2]}')
 
 
133
  ```