---
language: ja
license: apache-2.0
tags:
- speech
- speaker-diarization
datasets:
- callhome
---

# Fine-tuned XLSR-53 large model for speaker diarization of Japanese phone calls

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese phone-call data from [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).

## Usage

The model can be used directly as follows.

```python
import numpy as np
import torch
from pydub import AudioSegment
from transformers import Wav2Vec2ForAudioFrameClassification, Wav2Vec2FeatureExtractor


def _make_timegrid(sound_duration: float, total_len: int):
    # Split the clip into `total_len` equal-width frames and return
    # the start and end time (in seconds) of each frame.
    start_timegrid = np.linspace(0, sound_duration, total_len + 1)
    dt = start_timegrid[1] - start_timegrid[0]
    end_timegrid = start_timegrid + dt
    return start_timegrid[:total_len], end_timegrid[:total_len]


feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16_000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")

filepath = "/path/to/file.wav"
sound = AudioSegment.from_file(filepath)
# Resample to 16 kHz and downmix to mono; for stereo files,
# get_array_of_samples() would otherwise return interleaved channels.
sound = sound.set_frame_rate(16_000).set_channels(1)
sound_duration = sound.duration_seconds

samples = np.array(sound.get_array_of_samples(), dtype=np.float32)
feature = feature_extractor(samples, sampling_rate=16_000).input_values[0]
input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
    logits = model(input_values).logits
# One predicted label per output frame, as plain Python ints.
pred = logits.argmax(dim=-1).squeeze(0).tolist()
start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))

print("sec speaker_label")
for p, start_time in zip(pred, start_timegrid):
    print(f"{start_time:.4f} {p}")
```

## Training

The model was trained on the Japanese phone-call corpus [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).

## License

[The Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0)
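The usage example in the Usage section prints one label per frame. For readability it is often convenient to collapse consecutive frames that share a label into speaker segments. The sketch below does that; `frames_to_segments` is a hypothetical helper, not part of this model or of `transformers`, and the card does not document what each label index means (e.g. which speaker, or whether one index denotes silence), so the labels are passed through unchanged.

```python
from typing import List, Sequence, Tuple


def frames_to_segments(
    labels: Sequence[int], frame_starts: Sequence[float], frame_ends: Sequence[float]
) -> List[Tuple[float, float, int]]:
    """Hypothetical helper: merge runs of identical per-frame labels into (start, end, label)."""
    segments: List[Tuple[float, float, int]] = []
    for label, start, end in zip(labels, frame_starts, frame_ends):
        if segments and segments[-1][2] == label:
            # Same label as the previous frame: extend the current segment's end time.
            prev_start, _, _ = segments[-1]
            segments[-1] = (prev_start, end, label)
        else:
            # First frame, or the label changed: open a new segment.
            segments.append((start, end, label))
    return segments


if __name__ == "__main__":
    # Toy input with hypothetical values (not model output): 6 frames of 0.02 s each.
    labels = [0, 0, 1, 1, 1, 0]
    starts = [i * 0.02 for i in range(6)]
    ends = [s + 0.02 for s in starts]
    for start, end, label in frames_to_segments(labels, starts, ends):
        print(f"{start:.2f}-{end:.2f} label {label}")
```

With the usage example, `frames_to_segments(pred, start_timegrid, end_timegrid)` would turn the frame-level output into a list of contiguous segments.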