|
--- |
|
license: cc-by-4.0 |
|
language: |
|
- hak |
|
pipeline_tag: automatic-speech-recognition |
|
--- |
|
# Model Card for whisper-large-v3-taiwanese-hakka |
|
|
|
<!-- Provide a quick summary of what the model is/does. --> |
|
This model is a fine-tuned version of [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) for Taiwanese Hakka. The ID of each dialect was used as a prompt during training, as an experiment to see whether adding prompts when fine-tuning Whisper on multiple dialects gives better results.
|
|
|
## Dialects and IDs

- 四縣 (Sixian): htia_sixian

- 海陸 (Hailu): htia_hailu

- 大埔 (Dapu): htia_dapu

- 饒平 (Raoping): htia_raoping

- 詔安 (Zhaoan): htia_zhaoan

- 南四縣 (Nansixian): htia_nansixian
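
For programmatic use, the list above can be expressed as a small lookup from romanized dialect name to prompt ID. The sketch below is purely illustrative; the dictionary and its key names are not part of the released model, and `get_prompt_ids` is the same call used in the usage example further down.

```python
from transformers import AutoProcessor

# Hypothetical convenience mapping from romanized dialect name to its prompt ID.
DIALECT_IDS = {
    "sixian": "htia_sixian",
    "hailu": "htia_hailu",
    "dapu": "htia_dapu",
    "raoping": "htia_raoping",
    "zhaoan": "htia_zhaoan",
    "nansixian": "htia_nansixian",
}

processor = AutoProcessor.from_pretrained("formospeech/whisper-large-v3-taiwanese-hakka")

# Convert the chosen dialect ID into prompt token IDs for generation
# (see the full usage example below).
prompt_ids = processor.get_prompt_ids(DIALECT_IDS["hailu"])
```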
|
|
|
### Training process |
|
The model was trained with the following hyperparameters (an illustrative mapping to Hugging Face training arguments is sketched after the list):
|
|
|
- Batch size: 32

- Epochs: 3

- Warmup steps: 50

- Total steps: 42,549

- Learning rate: 7e-5

- Data augmentation: No
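
The training script itself is not included in this card. As a rough sketch only, these hyperparameters could be expressed with the `transformers` `Seq2SeqTrainingArguments`; the output directory, `fp16`, and the single-device interpretation of the batch size are assumptions, not reported settings.

```python
from transformers import Seq2SeqTrainingArguments

# Illustrative only: how the reported hyperparameters might map onto the Hugging Face
# Trainer. The actual training setup may have differed.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-large-v3-taiwanese-hakka",  # assumed output path
    per_device_train_batch_size=32,  # batch size 32, assuming a single device
    num_train_epochs=3,              # 3 epochs (~42,549 optimizer steps in total)
    warmup_steps=50,
    learning_rate=7e-5,
    fp16=True,                       # assumption: mixed-precision training
)
```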
|
|
|
|
|
### How to use |
|
|
|
```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "formospeech/whisper-large-v3-taiwanese-hakka"
dialect_id = "htia_sixian"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

# Pass the dialect ID as a prompt so decoding is conditioned on the target dialect.
generate_kwargs = {
    "language": "Chinese",
    "prompt_ids": torch.from_numpy(processor.get_prompt_ids(dialect_id)).to(device),
}

result = pipe("path/to/my_audio.wav", generate_kwargs=generate_kwargs)

# The prompt text is echoed at the start of the output, so strip the dialect ID.
transcription = result["text"]
print(transcription.replace(f" {dialect_id}", ""))
```