Add code example
README.md CHANGED
@@ -69,3 +69,49 @@ torchrun --standalone --nnodes=1 --nproc-per-node=2 ../train_w2v2_bert.py \
 --mask_feature_prob 0.0 \
 --mask_feature_length 10
 ```
+
+## Usage
+
+```python
+# pip install -U torch soundfile transformers
+
+import torch
+import soundfile as sf
+from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
+
+# Config
+model_name = 'Yehor/w2v-bert-2.0-uk'
+device = 'cuda:1'  # or 'cpu'
+sampling_rate = 16_000
+
+# Load the model and its processor
+asr_model = AutoModelForCTC.from_pretrained(model_name).to(device)
+processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
+
+# Audio files to transcribe, expected to be sampled at 16 kHz
+paths = [
+    'sample1.wav',
+]
+
+# Read the audio data
+audio_inputs = []
+for path in paths:
+    audio_input, _ = sf.read(path)
+    audio_inputs.append(audio_input)
+
+# Transcribe the audio
+inputs = processor(audio_inputs, sampling_rate=sampling_rate).input_features
+features = torch.tensor(inputs).to(device)
+
+with torch.no_grad():
+    logits = asr_model(features).logits
+
+predicted_ids = torch.argmax(logits, dim=-1)
+predictions = processor.batch_decode(predicted_ids)
+
+# Log outputs
+print('---')
+print('Predictions:')
+print(predictions)
+print('---')
+```
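
Note that `torch.tensor(inputs)` only stacks cleanly when every clip in the batch has the same length. For several files of different durations, the processor can pad the batch and return tensors directly. A minimal sketch, assuming the standard `padding`/`return_tensors` arguments of Hugging Face feature extractors and that the model accepts an `attention_mask`:

```python
# Pad variable-length clips into a single batch and get PyTorch tensors back
batch = processor(audio_inputs, sampling_rate=sampling_rate,
                  padding=True, return_tensors='pt')
features = batch.input_features.to(device)
attention_mask = batch.attention_mask.to(device)  # marks real vs. padded frames

with torch.no_grad():
    logits = asr_model(features, attention_mask=attention_mask).logits

predictions = processor.batch_decode(torch.argmax(logits, dim=-1))
```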
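
If reference transcriptions are available for the audio files, a word error rate can be computed against the predictions with the `evaluate` library. A minimal sketch, where the `references` list is a hypothetical stand-in for the ground-truth texts:

```python
# pip install -U evaluate jiwer
import evaluate

# Hypothetical ground-truth transcriptions, one per file in `paths`
references = [
    'reference text for sample1.wav',
]

wer_metric = evaluate.load('wer')
print('WER:', wer_metric.compute(predictions=predictions, references=references))
```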