---
language: ja
license: apache-2.0
tags:
  - speech
  - speaker-diarization
datasets:
  - callhome
---

# Fine-tuned XLSR-53 large model for speaker diarization in Japanese phone calls

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese phone-call data from the [CallHome](https://media.talkbank.org/ca/CallHome/jpn/) corpus.

## Usage
The model can be used directly as follows.

```python
import numpy as np
import torch
from pydub import AudioSegment

from transformers import Wav2Vec2ForAudioFrameClassification, Wav2Vec2FeatureExtractor


def _make_timegrid(sound_duration: float, total_len: int):
    # Split [0, sound_duration] into total_len equal frames and return
    # the start and end time (in seconds) of each frame.
    start_timegrid = np.linspace(0, sound_duration, total_len + 1)
    dt = start_timegrid[1] - start_timegrid[0]
    end_timegrid = start_timegrid + dt
    return start_timegrid[:total_len], end_timegrid[:total_len]


feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")

filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# The model expects 16 kHz mono audio. Downmix so that stereo files
# do not yield interleaved left/right samples.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
feature = feature_extractor(samples, sampling_rate=16_000).input_values[0]
input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    logits = model(input_values).logits  # shape: (1, num_frames, num_labels)
# Most likely speaker label for each frame, as plain Python ints.
pred = logits.argmax(dim=-1).squeeze(0).tolist()
start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))

print("sec     speaker_label")
for p, start_time in zip(pred, start_timegrid):
    print(f"{start_time:.4f}  {p}")
```
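The example above prints one speaker label per frame. For a more readable diarization result, consecutive frames that share a label can be collapsed into `(start, end, label)` segments. The following is a minimal sketch, assuming `pred` has been converted to a list of ints and the start/end grids come from `_make_timegrid`; `frames_to_segments` is a hypothetical helper, not part of this model or of `transformers`.

```python
import numpy as np


def frames_to_segments(labels, start_times, end_times):
    """Hypothetical helper: merge runs of identical frame labels into segments.

    labels, start_times and end_times are equal-length sequences, e.g. pred
    and the two arrays returned by _make_timegrid in the usage example.
    Returns a list of (start_sec, end_sec, label) tuples.
    """
    segments = []
    for label, start, end in zip(labels, start_times, end_times):
        if segments and segments[-1][2] == label:
            # Same label as the previous frame: extend the current segment.
            prev_start = segments[-1][0]
            segments[-1] = (prev_start, float(end), label)
        else:
            # First frame or label change: open a new segment.
            segments.append((float(start), float(end), label))
    return segments


if __name__ == "__main__":
    # Toy input: six frames of 0.5 s each with speaker labels 0 0 1 1 1 0.
    starts = np.linspace(0.0, 3.0, 7)[:6]
    ends = starts + 0.5
    print(frames_to_segments([0, 0, 1, 1, 1, 0], starts, ends))
    # → [(0.0, 1.0, 0), (1.0, 2.5, 1), (2.5, 3.0, 0)]
```

This sketch applies no smoothing, so an isolated single-frame label flip shows up as a very short segment; a median filter over `pred` or a minimum-duration threshold are common ways to suppress such flips.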

## Training

The model was trained on the Japanese phone-call corpus [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).

## License

[The Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0)