---
language: ja
license: apache-2.0
tags:
- speech
- speaker-diarization
datasets:
- callhome
---
# Fine-tuned XLSR-53 large model for speaker diarization of Japanese phone calls
This model is [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) fine-tuned for speaker diarization on Japanese telephone conversations from the [CallHome](https://media.talkbank.org/ca/CallHome/jpn/) corpus.
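Because the model is a frame classifier built on the wav2vec 2.0 convolutional feature encoder, it predicts one label per encoder frame rather than per audio sample. The sketch below estimates how many frames a clip produces. It assumes the standard wav2vec 2.0 / XLSR-53 encoder settings (`conv_kernel` and `conv_stride` in `Wav2Vec2Config`); for this checkpoint you can read the actual values from `model.config.conv_kernel` and `model.config.conv_stride`. `num_output_frames` is a hypothetical helper, not part of `transformers`.

```python
from typing import Sequence

# Assumed encoder settings (standard wav2vec 2.0 / XLSR-53 defaults);
# check model.config.conv_kernel / model.config.conv_stride for the real values.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(
    num_samples: int,
    kernels: Sequence[int] = CONV_KERNEL,
    strides: Sequence[int] = CONV_STRIDE,
) -> int:
    """Number of frames produced by a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        # Output length of a 1-D convolution without padding.
        length = (length - kernel) // stride + 1
    return length


# One second of 16 kHz audio yields 49 frames (about 20 ms per frame).
print(num_output_frames(16_000))
```

With a total stride of 320 samples, each frame covers roughly 20 ms of 16 kHz audio, which is why the usage example spreads the clip duration evenly over the predicted frames.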
## Usage
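The model can be used directly as follows. It predicts a speaker label for each audio frame, and the script prints the start time of every frame together with its label.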
```python
import numpy as np
import torch
from pydub import AudioSegment
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForAudioFrameClassification


def _make_timegrid(sound_duration: float, total_len: int):
    """Split [0, sound_duration] into total_len equal frames; return their start and end times."""
    start_timegrid = np.linspace(0, sound_duration, total_len + 1)
    dt = start_timegrid[1] - start_timegrid[0]
    end_timegrid = start_timegrid + dt
    return start_timegrid[:total_len], end_timegrid[:total_len]


feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")

filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# The model expects 16 kHz mono audio; downmix so stereo samples are not interleaved.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
feature = feature_extractor(samples, sampling_rate=16_000).input_values[0]
input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)  # (1, num_samples)

with torch.no_grad():
    logits = model(input_values).logits  # (1, num_frames, num_labels)
pred = logits.argmax(dim=-1).squeeze(0)  # one label per frame

start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))
print("sec speaker_label")
for p, start_time in zip(pred.tolist(), start_timegrid):
    print(f"{start_time:.4f} {p}")
```
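The script above prints one line per frame. For a more compact, segment-level view, consecutive frames with the same label can be merged into `(start, end, label)` segments. The following is a minimal sketch: `frames_to_segments` is a hypothetical helper, and the meaning of each label index (for example, which index denotes silence or a given speaker) is not documented here, so it is passed through unchanged.

```python
from itertools import groupby
from typing import List, Sequence, Tuple


def frames_to_segments(
    labels: Sequence[int],
    starts: Sequence[float],
    ends: Sequence[float],
) -> List[Tuple[float, float, int]]:
    """Merge runs of consecutive frames with the same label into (start, end, label) segments."""
    segments: List[Tuple[float, float, int]] = []
    idx = 0
    for label, run in groupby(labels):
        n = sum(1 for _ in run)  # length of this run of identical labels
        segments.append((float(starts[idx]), float(ends[idx + n - 1]), label))
        idx += n
    return segments


# Toy example with 1-second frames.
print(frames_to_segments([0, 0, 1, 1, 1, 0], [0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]))
```

With the variables from the usage example, this would be called as `frames_to_segments(pred.tolist(), start_timegrid, end_timegrid)`.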
## Training
The model was fine-tuned on the Japanese phone-call corpus [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).
## License
[The Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0)