---
language:
- en
license: apache-2.0
base_model: openai/whisper-small
tags:
- speaker-diarization
- speaker-segmentation
- generated_from_trainer
model-index:
- name: speaker-segmentation-eng
  results: []
---


# speaker-segmentation-eng

This model is a fine-tuned version of [openai/whisper-small](https://huggingface.co/openai/whisper-small) on the diarizers-community/callhome dataset.
It achieves the following results on the evaluation set:
- Loss: 0.4666
- DER (diarization error rate): 0.1827
- False Alarm: 0.0590
- Missed Detection: 0.0715
- Confusion: 0.0522
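
The diarization error rate is conventionally the sum of the false alarm, missed detection, and speaker confusion rates, and the figures reported here are consistent with that. A quick check, using only the numbers listed above:

```python
# Values copied from the evaluation results above.
false_alarm = 0.0590
missed_detection = 0.0715
confusion = 0.0522

# DER is the sum of its three error components.
der = false_alarm + missed_detection + confusion
print(f"DER = {der:.4f}")
```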

## Model description

This segmentation model was trained on English data (Callhome) using [diarizers](https://github.com/huggingface/diarizers). It can be loaded with two lines of code:
```python
from diarizers import SegmentationModel

segmentation_model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
```

To use it within a pyannote speaker diarization pipeline, load the [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) pipeline and convert the model to a pyannote-compatible format:

```python
from diarizers import SegmentationModel
from pyannote.audio import Pipeline
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# load the pre-trained pyannote pipeline
# (the pipeline is gated on the Hub: accept its user conditions and log in first)
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(device)

# load the fine-tuned segmentation model and convert it to a pyannote model
model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
model = model.to_pyannote_model()
# swap it in for the pipeline's default segmentation model
pipeline._segmentation.model = model.to(device)
```

You can now run the pipeline on an example from the dataset:

```python
from datasets import load_dataset
# load dataset example
dataset = load_dataset("diarizers-community/callhome", "eng", split="data")
sample = dataset[0]["audio"]

# pre-process inputs
sample["waveform"] = torch.from_numpy(sample.pop("array")[None, :]).to(device, dtype=model.dtype)
sample["sample_rate"] = sample.pop("sampling_rate")

# perform inference
diarization = pipeline(sample)

# dump the diarization output to disk using RTTM format
with open("audio.rttm", "w") as rttm:
    diarization.write_rttm(rttm)
```
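
The RTTM file written above is plain text with one speaker turn per line. Below is a minimal sketch for reading it back, assuming the standard ten-field `SPEAKER` line layout (file id in field 2, onset in field 4, duration in field 5, speaker label in field 8). `Turn` and `parse_rttm_line` are hypothetical helpers for illustration, not part of diarizers or pyannote, and the example line uses made-up values rather than real model output:

```python
from typing import NamedTuple


class Turn(NamedTuple):
    file_id: str
    start: float  # seconds
    end: float    # seconds
    speaker: str


def parse_rttm_line(line: str) -> Turn:
    """Parse one SPEAKER line of an RTTM file into a Turn."""
    fields = line.split()
    if len(fields) < 8 or fields[0] != "SPEAKER":
        raise ValueError(f"not an RTTM SPEAKER line: {line!r}")
    start = float(fields[3])
    duration = float(fields[4])
    return Turn(fields[1], start, start + duration, fields[7])


# Illustrative line (made-up values, not model output).
example = "SPEAKER audio 1 0.500 2.250 <NA> <NA> SPEAKER_00 <NA> <NA>"
turn = parse_rttm_line(example)
print(f"{turn.speaker}: {turn.start:.2f}s to {turn.end:.2f}s")
```

To read a whole file, apply `parse_rttm_line` to each non-empty line of `audio.rttm`.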

To diarize a single audio file, the complete script is:

```python
from diarizers import SegmentationModel
from pyannote.audio import Pipeline
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# load the pre-trained pyannote pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(device)

# load the fine-tuned segmentation model and swap it into the pipeline
model = SegmentationModel().from_pretrained("nehulagrawal/speaker-segmentation-eng")
model = model.to_pyannote_model()
pipeline._segmentation.model = model.to(device)

# run diarization on a local audio file (replace with your own path)
diarization = pipeline("audio.wav")

# dump the diarization output to disk using RTTM format
with open("audio.rttm", "w") as rttm:
    diarization.write_rttm(rttm)
```

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.001
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- num_epochs: 10
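
For reference, a cosine schedule decays the learning rate from its initial value toward zero over the course of training. The sketch below assumes no warmup and a single half-cosine cycle, neither of which is stated in the hyperparameters above; the total of 1810 steps is the final step count in the training results table. `cosine_lr` is a hypothetical helper, not a library function:

```python
import math

BASE_LR = 0.001      # learning_rate from the list above
TOTAL_STEPS = 1810   # final step in the training results table


def cosine_lr(step: int, base_lr: float = BASE_LR, total_steps: int = TOTAL_STEPS) -> float:
    """Learning rate at `step` under a half-cosine decay without warmup (assumed)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Learning rate at the start, midpoint, and end of training.
for step in (0, 905, 1810):
    print(step, f"{cosine_lr(step):.6f}")
```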

### Training results

| Training Loss | Epoch | Step | Validation Loss | DER    | False Alarm | Missed Detection | Confusion |
|:-------------:|:-----:|:----:|:---------------:|:------:|:-----------:|:----------------:|:---------:|
| 0.4224        | 1.0   | 181  | 0.4837          | 0.1939 | 0.0599      | 0.0764           | 0.0576    |
| 0.409         | 2.0   | 362  | 0.4692          | 0.1884 | 0.0618      | 0.0724           | 0.0543    |
| 0.3919        | 3.0   | 543  | 0.4700          | 0.1875 | 0.0638      | 0.0698           | 0.0540    |
| 0.3693        | 4.0   | 724  | 0.4718          | 0.1848 | 0.0602      | 0.0714           | 0.0533    |
| 0.358         | 5.0   | 905  | 0.4606          | 0.1810 | 0.0544      | 0.0754           | 0.0512    |
| 0.355         | 6.0   | 1086 | 0.4631          | 0.1826 | 0.0638      | 0.0677           | 0.0512    |
| 0.3563        | 7.0   | 1267 | 0.4646          | 0.1809 | 0.0587      | 0.0716           | 0.0505    |
| 0.347         | 8.0   | 1448 | 0.4682          | 0.1820 | 0.0581      | 0.0720           | 0.0519    |
| 0.3463        | 9.0   | 1629 | 0.4684          | 0.1827 | 0.0586      | 0.0718           | 0.0523    |
| 0.3299        | 10.0  | 1810 | 0.4666          | 0.1827 | 0.0590      | 0.0715           | 0.0522    |
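
The final epoch is not the best row on every metric. A short sketch that picks the best epoch by validation loss and by diarization error rate, with the values copied from the table above:

```python
# (epoch, validation_loss, der) copied from the table above.
rows = [
    (1, 0.4837, 0.1939),
    (2, 0.4692, 0.1884),
    (3, 0.4700, 0.1875),
    (4, 0.4718, 0.1848),
    (5, 0.4606, 0.1810),
    (6, 0.4631, 0.1826),
    (7, 0.4646, 0.1809),
    (8, 0.4682, 0.1820),
    (9, 0.4684, 0.1827),
    (10, 0.4666, 0.1827),
]

# Lowest validation loss and lowest diarization error rate across epochs.
best_loss = min(rows, key=lambda r: r[1])
best_der = min(rows, key=lambda r: r[2])

print(f"lowest validation loss: epoch {best_loss[0]} ({best_loss[1]:.4f})")
print(f"lowest error rate: epoch {best_der[0]} ({best_der[2]:.4f})")
```

The differences between the later epochs are small, so this is a reading aid for the table rather than a checkpoint recommendation.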


### Framework versions

- Transformers 4.40.1
- PyTorch 2.2.1+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1