Yehor commited on
Commit
3687bc0
1 Parent(s): 3f87c4d

Add code example

Browse files
Files changed (1) hide show
  1. README.md +48 -0
README.md CHANGED
@@ -69,3 +69,51 @@ torchrun --standalone --nnodes=1 --nproc-per-node=2 ../train_w2v2_bert.py \
69
  --mask_feature_prob 0.0 \
70
  --mask_feature_length 10
71
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  --mask_feature_prob 0.0 \
70
  --mask_feature_length 10
71
  ```
72
+
73
+ ## Usage
74
+
75
+ ```python
76
+ # pip install -U torch soundfile transformers
77
+
78
+ import torch
79
+ import soundfile as sf
80
+ import evaluate
81
+ from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
82
+
83
+ # Config
84
+ model_name = 'Yehor/w2v-bert-2.0-uk'
85
+ device = 'cuda:1' # or cpu
86
+ sampling_rate = 16_000
87
+
88
+ # Load the model
89
+ asr_model = AutoModelForCTC.from_pretrained(model_name).to(device)
90
+ processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
91
+
92
+ paths = [
93
+ 'sample1.wav',
94
+ ]
95
+
96
+ # Extract audio
97
+ audio_inputs = []
98
+ for path in paths:
99
+ audio_input, _ = sf.read(path)
100
+ audio_inputs.append(audio_input)
101
+
102
+ # Transcribe the audio
103
+ inputs = processor(audio_inputs, sampling_rate=sampling_rate).input_features
104
+ features = torch.tensor(inputs).to(device)
105
+
106
+ with torch.no_grad():
107
+ logits = asr_model(features).logits
108
+
109
+ predicted_ids = torch.argmax(logits, dim=-1)
110
+ predictions = processor.batch_decode(predicted_ids)
111
+
112
+ # Log outputs
113
+ print('---')
114
+ print('Predictions:')
115
+ print(predictions)
116
+ print('References:')
117
+ print(references)
118
+ print('---')
119
+ ```