---
language:
- en
license: apache-2.0
base_model: openai/whisper-small
tags:
- speaker-diarization
- speaker-segmentation
- generated_from_trainer
model-index:
- name: speaker-segmentation-eng
results: []
---
# speaker-segmentation-eng
This model is a fine-tuned version of [openai/whisper-small](https://huggingface.co/openai/whisper-small) on the diarizers-community/callhome dataset.
It achieves the following results on the evaluation set:
- Loss: 0.4666
- Der: 0.1827
- False Alarm: 0.0590
- Missed Detection: 0.0715
- Confusion: 0.0522
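The diarization error rate (DER) is conventionally the sum of the false alarm, missed detection, and speaker confusion rates. As a quick sanity check, assuming that convention applies to the metrics reported here, the figures above add up (the variable names are illustrative only):

```python
# Final evaluation metrics as reported in this model card.
false_alarm = 0.0590
missed_detection = 0.0715
confusion = 0.0522
reported_der = 0.1827

# Assumption: DER = false alarm + missed detection + confusion.
computed_der = false_alarm + missed_detection + confusion

# Allow a small tolerance because the reported figures are rounded to four decimals.
assert abs(computed_der - reported_der) < 5e-4
print(f"computed DER = {computed_der:.4f}, reported DER = {reported_der:.4f}")
```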
## Model description
This segmentation model has been trained on English data (Callhome) using diarizers. It can be loaded with two lines of code:
```python
from diarizers import SegmentationModel
segmentation_model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
```
To use it within a pyannote speaker diarization pipeline, load the [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) pipeline and convert the model to a pyannote-compatible format:
```python
from diarizers import SegmentationModel
from pyannote.audio import Pipeline
from datasets import load_dataset
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# load the pre-trained pyannote pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(device)
model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
model = model.to_pyannote_model()
pipeline._segmentation.model = model.to(device)
```
You can now use the pipeline on audio examples:
```python
from datasets import load_dataset
# load dataset example
dataset = load_dataset("diarizers-community/callhome", "eng", split="data")
sample = dataset[0]["audio"]
# pre-process inputs
sample["waveform"] = torch.from_numpy(sample.pop("array")[None, :]).to(device, dtype=model.dtype)
sample["sample_rate"] = sample.pop("sampling_rate")
# perform inference
diarization = pipeline(sample)
# dump the diarization output to disk using RTTM format
with open("audio.rttm", "w") as rttm:
    diarization.write_rttm(rttm)
```
To run the pipeline directly on a single audio file, the complete script is:
```python
from diarizers import SegmentationModel
from pyannote.audio import Pipeline
from datasets import load_dataset
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# load the pre-trained pyannote pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(device)
model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
model = model.to_pyannote_model()
pipeline._segmentation.model = model.to(device)
diarization = pipeline("audio.wav")
with open("audio.rttm", "w") as rttm:
    diarization.write_rttm(rttm)
```
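The RTTM file written above can be read back with plain Python. The following is a minimal sketch, assuming the standard space-separated RTTM layout (onset in the fourth field, duration in the fifth, speaker label in the eighth). The `Turn` type, the `parse_rttm` helper, and the sample lines are hypothetical and not part of diarizers or pyannote:

```python
from typing import NamedTuple


class Turn(NamedTuple):
    """One speaker turn (hypothetical helper type)."""
    speaker: str
    start: float  # seconds
    end: float    # seconds


def parse_rttm(text: str) -> list[Turn]:
    """Parse SPEAKER lines of RTTM text into a list of turns."""
    turns = []
    for line in text.splitlines():
        fields = line.split()
        # Skip blank lines and any non-SPEAKER records.
        if len(fields) < 8 or fields[0] != "SPEAKER":
            continue
        start = float(fields[3])
        duration = float(fields[4])
        turns.append(Turn(fields[7], start, start + duration))
    return turns


# Made-up RTTM content for illustration, not real model output.
example = """SPEAKER audio 1 0.500 2.250 <NA> <NA> SPEAKER_00 <NA> <NA>
SPEAKER audio 1 3.000 1.500 <NA> <NA> SPEAKER_01 <NA> <NA>
"""

turns = parse_rttm(example)
for turn in turns:
    print(f"{turn.speaker}: {turn.start:.2f}s to {turn.end:.2f}s")
```

To parse the file produced earlier, pass `open("audio.rttm").read()` instead of the inline example string.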
## Intended uses & limitations
More information needed
## Training and evaluation data
More information needed
## Training procedure
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.001
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- num_epochs: 10
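To give a feel for what the cosine scheduler does to the learning rate, here is a rough sketch of a half-cosine decay from the base rate down to zero. It assumes no warmup steps and takes the total step count (1810) from the training results table; neither assumption is confirmed by the card, and `cosine_lr` is an illustrative helper, not the trainer's own implementation:

```python
import math

BASE_LR = 0.001      # learning_rate from the hyperparameter list
TOTAL_STEPS = 1810   # last step in the training results table (assumed total)


def cosine_lr(step: int, base_lr: float = BASE_LR, total_steps: int = TOTAL_STEPS) -> float:
    """Half-cosine decay from base_lr to 0, assuming no warmup."""
    # Clamp progress to [0, 1] so out-of-range steps stay well defined.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Learning rate at the end of each epoch (181 steps per epoch in the table).
for epoch in range(0, 11):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch * 181):.6f}")
```

Under these assumptions the rate starts at 0.001, reaches half that value midway through training, and decays to zero at the final step.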
### Training results
| Training Loss | Epoch | Step | Validation Loss | Der | False Alarm | Missed Detection | Confusion |
|:-------------:|:-----:|:----:|:---------------:|:------:|:-----------:|:----------------:|:---------:|
| 0.4224 | 1.0 | 181 | 0.4837 | 0.1939 | 0.0599 | 0.0764 | 0.0576 |
| 0.409 | 2.0 | 362 | 0.4692 | 0.1884 | 0.0618 | 0.0724 | 0.0543 |
| 0.3919 | 3.0 | 543 | 0.4700 | 0.1875 | 0.0638 | 0.0698 | 0.0540 |
| 0.3693 | 4.0 | 724 | 0.4718 | 0.1848 | 0.0602 | 0.0714 | 0.0533 |
| 0.358 | 5.0 | 905 | 0.4606 | 0.1810 | 0.0544 | 0.0754 | 0.0512 |
| 0.355 | 6.0 | 1086 | 0.4631 | 0.1826 | 0.0638 | 0.0677 | 0.0512 |
| 0.3563 | 7.0 | 1267 | 0.4646 | 0.1809 | 0.0587 | 0.0716 | 0.0505 |
| 0.347 | 8.0 | 1448 | 0.4682 | 0.1820 | 0.0581 | 0.0720 | 0.0519 |
| 0.3463 | 9.0 | 1629 | 0.4684 | 0.1827 | 0.0586 | 0.0718 | 0.0523 |
| 0.3299 | 10.0 | 1810 | 0.4666 | 0.1827 | 0.0590 | 0.0715 | 0.0522 |
### Framework versions
- Transformers 4.40.1
- Pytorch 2.2.1+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1