pushing code to the hub
README.md
CHANGED
@@ -30,7 +30,6 @@ It is based on the [mHuBERT-147](https://huggingface.co/utter-project/mHuBERT-14
 
 ## Training Parameters
 
 The training parameters are available in config.yaml.
-We downsample the commonvoice dataset to 70,000 utterances.
 
 ## ASR Model class
 
@@ -41,4 +40,15 @@ The code is available in [CTC_model.py](https://huggingface.co/naver/mHuBERT-147
 
 ## Running inference
 
 The run_asr.py file illustrates how to load the model for inference (**load_asr_model**) and how to produce a transcription for a file (**run_asr_inference**).
-Please follow the [requirements file](https://huggingface.co/naver/mHuBERT-147-ASR-fr/blob/main/requirements.txt) to avoid incorrect model loading.
+Please follow the [requirements file](https://huggingface.co/naver/mHuBERT-147-ASR-fr/blob/main/requirements.txt) to avoid incorrect model loading.
+
+Here is a simple example of the inference loop. Please note that the sampling rate must be 16,000 Hz.
+
+```
+from inference_code.run_inference import load_asr_model, run_asr_inference
+
+model, processor = load_asr_model()
+
+prediction = run_asr_inference(model, processor, your_audio_file)
+```
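For completeness, here is one way to build the your_audio_file argument from a local recording. This is an editor's sketch rather than part of the repository: run_asr_inference expects a dict with "array" and "sampling_rate" keys (the datasets audio format), and both librosa and the path my_audio.wav are assumptions.

```
# Sketch: prepare the `your_audio_file` dict consumed by run_asr_inference.
# Assumes librosa is installed; "my_audio.wav" is a placeholder path.
import librosa

from inference_code.run_inference import load_asr_model, run_asr_inference

# Load and resample to 16,000 Hz, the rate the model expects.
array, sampling_rate = librosa.load("my_audio.wav", sr=16_000)
your_audio_file = {"array": array, "sampling_rate": sampling_rate}

model, processor = load_asr_model()
print(run_asr_inference(model, processor, your_audio_file))
```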
config.yaml
ADDED
@@ -0,0 +1,24 @@
+group_by_length: True
+evaluation_strategy: "steps"
+num_train_epochs: 100
+fp16: False
+gradient_checkpointing: True
+eval_steps: 10000
+save_steps: 10000
+logging_steps: 10000
+learning_rate: 1e-4
+adam_beta1: 0.9
+adam_beta2: 0.98
+adam_epsilon: 1e-08
+warmup_ratio: 0.2
+save_total_limit: 4
+load_best_model_at_end: True
+per_device_train_batch_size: 8
+per_device_eval_batch_size: 2
+metric_for_best_model: "cer"
+greater_is_better: False
+gradient_accumulation_steps: 8
+final_dropout: 0.3
+seed: 3452
+add_interface_layer: True
+num_interface_layers: 3
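The training script itself is not part of this commit, so the following is only a plausible sketch of how config.yaml could be split between Hugging Face TrainingArguments and model-level options (final_dropout, add_interface_layer, num_interface_layers); the split and the output_dir value are assumptions.

```
# Sketch only: loading config.yaml and separating Trainer hyperparameters from
# model options. The training code is not in this commit; this split is assumed.
import yaml
from transformers import TrainingArguments

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# PyYAML parses bare scientific notation such as 1e-4 as a string, so coerce.
for key in ("learning_rate", "adam_epsilon"):
    cfg[key] = float(cfg[key])

# Options that look model-specific rather than Trainer-specific (assumption).
model_keys = {"final_dropout", "add_interface_layer", "num_interface_layers"}
model_options = {k: cfg.pop(k) for k in model_keys}

training_args = TrainingArguments(output_dir="outputs", **cfg)  # output_dir is assumed
print(training_args.learning_rate, model_options)
```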
CTC_model.py → inference_code/CTC_model.py
RENAMED
File without changes
inference_code/__pycache__/CTC_model.cpython-39.pyc
ADDED
Binary file (3.47 kB).
inference_code/__pycache__/run_inference.cpython-39.pyc
ADDED
Binary file (2.21 kB).
inference_code/run_inference.py
ADDED
@@ -0,0 +1,80 @@
+"""
+Inference main script.
+
+Author: Marcely Zanon Boito, 2024
+"""
+
+from .CTC_model import mHubertForCTC
+
+import torch
+from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
+from transformers import HubertConfig
+
+from datasets import load_dataset
+
+fbk_test_id = 'FBK-MT/Speech-MASSIVE-test'
+mhubert_id = 'utter-project/mHuBERT-147'
+
+def load_asr_model():
+    def init_config():
+        config = HubertConfig.from_pretrained(mhubert_id)
+        config.pad_token_id = processor.tokenizer.pad_token_id
+        config.ctc_token_id = processor.tokenizer.convert_tokens_to_ids('[CTC]')
+        config.vocab_size = len(processor.tokenizer)
+
+        config.output_hidden_states = False
+        config.add_interface = True
+        config.num_interface_layers = 3
+        return config
+
+    # Load the ASR model
+    tokenizer = Wav2Vec2CTCTokenizer('vocab.json', unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
+    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(mhubert_id)
+    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+    config = init_config()
+    model = mHubertForCTC.from_pretrained("naver/mHuBERT-147-ASR-fr", config=config)
+    model.eval()
+    return model, processor
+
+def run_asr_inference(model, processor, example):
+    audio = processor(example["array"], sampling_rate=example["sampling_rate"]).input_values[0]
+    input_values = torch.tensor(audio).unsqueeze(0)
+
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    pred_ids = torch.argmax(logits, dim=-1)
+
+    prediction = processor.batch_decode(pred_ids)[0].replace('[CTC]', "")
+    return prediction
+
+if __name__ == '__main__':
+
+    # Load the dataset in streaming mode
+    dataset = load_dataset(fbk_test_id, 'fr-FR', streaming=True)
+    dataset = dataset['test']
+    generator = iter(dataset)
+
+    # load model
+    model, processor = load_asr_model()
+    print(model)
+
+    # decode 10 examples from speech-MASSIVE
+    num_examples = 10
+    while num_examples > 0:
+        example = next(generator)
+
+        prediction = run_asr_inference(model, processor, example['audio'])
+
+        gold_standard = example['utt']
+
+        print("Gold standard:", gold_standard)
+        print("Prediction:", prediction)
+        print()
+        num_examples -= 1
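Since config.yaml selects the best checkpoint by CER, a natural follow-up to the loop above is to score the streamed predictions. The snippet below is an editor's sketch, not part of the commit; it assumes the evaluate package (with jiwer) is available and reuses the Speech-MASSIVE fields from run_inference.py.

```
# Sketch only (not in this commit): compute CER over a few Speech-MASSIVE examples.
# Assumes `evaluate` and `jiwer` are installed.
import evaluate
from datasets import load_dataset

from inference_code.run_inference import load_asr_model, run_asr_inference

cer_metric = evaluate.load("cer")
model, processor = load_asr_model()

dataset = load_dataset('FBK-MT/Speech-MASSIVE-test', 'fr-FR', streaming=True)['test']

predictions, references = [], []
for example in dataset.take(10):  # small sample, as in the script's main loop
    predictions.append(run_asr_inference(model, processor, example['audio']))
    references.append(example['utt'])

print("CER:", cer_metric.compute(predictions=predictions, references=references))
```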