frankenjoe commited on
Commit
4ca0364
1 Parent(s): 5ad7ffe

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -92,7 +92,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
92
  hidden_states = outputs[0]
93
  hidden_states = torch.mean(hidden_states, dim=1)
94
  logits_age = self.age(hidden_states)
95
- logits_gender = self.gender(hidden_states)
96
 
97
  return hidden_states, logits_age, logits_gender
98
 
@@ -140,10 +140,10 @@ def process_func(
140
 
141
  print(process_func(signal, sampling_rate))
142
  # Age child female male
143
- # [[ 0.3079211 -1.6096017 -2.1094327 3.1461434]]
144
 
145
  print(process_func(signal, sampling_rate, embeddings=True))
146
  # Pooled hidden states of last transformer layer
147
- # [[-0.00752167 0.0065819 -0.00746342 ... 0.00663632 0.00848748
148
- # 0.00599211]]
149
  ```
 
92
  hidden_states = outputs[0]
93
  hidden_states = torch.mean(hidden_states, dim=1)
94
  logits_age = self.age(hidden_states)
95
+ logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
96
 
97
  return hidden_states, logits_age, logits_gender
98
 
 
140
 
141
  print(process_func(signal, sampling_rate))
142
  # Age child female male
143
+ # [[ 0.3079211 0.00848487 0.0051472 0.9863679 ]]
144
 
145
  print(process_func(signal, sampling_rate, embeddings=True))
146
  # Pooled hidden states of last transformer layer
147
+ # [[ 0.00409924 0.00365688 0.02392936 ... 0.02349018 -0.13294911
148
+ # 0.1538802 ]]
149
  ```