Spaces: Sleeping
Irsh Vijayvargia committed
Commit • 42a4544 • Parent(s): 3e34e2e
Add application file

Browse files
- app.py +132 -0
- config/config.yaml +40 -0
- gradio.ipynb +292 -0
- requirements.txt +8 -0
- speech_id_checkpoint/saved_02.model +3 -0
- utils/.ipynb_checkpoints/VAD_segments-checkpoint.py +153 -0
- utils/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
- utils/.ipynb_checkpoints/data_load-checkpoint.py +57 -0
- utils/.ipynb_checkpoints/evaluation-checkpoint.py +192 -0
- utils/.ipynb_checkpoints/hparam-checkpoint.py +59 -0
- utils/.ipynb_checkpoints/kan-checkpoint.py +285 -0
- utils/.ipynb_checkpoints/speech_embedder_net-checkpoint.py +112 -0
- utils/.ipynb_checkpoints/utils-checkpoint.py +173 -0
- utils/VAD_segments.py +153 -0
- utils/__init__.py +0 -0
- utils/__pycache__/VAD_segments.cpython-39.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/data_load.cpython-39.pyc +0 -0
- utils/__pycache__/evaluation.cpython-39.pyc +0 -0
- utils/__pycache__/hparam.cpython-39.pyc +0 -0
- utils/__pycache__/kan.cpython-39.pyc +0 -0
- utils/__pycache__/speech_embedder_net.cpython-39.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/data_load.py +57 -0
- utils/evaluation.py +192 -0
- utils/hparam.py +59 -0
- utils/kan.py +285 -0
- utils/speech_embedder_net.py +112 -0
- utils/utils.py +173 -0
app.py
ADDED
@@ -0,0 +1,132 @@
import torch
import librosa
import numpy as np
import os
import webrtcvad
import wave
import contextlib
import gradio as gr

from utils.VAD_segments import *
from utils.hparam import hparam as hp
from utils.speech_embedder_net import *
from utils.evaluation import *

def read_wave(audio_data):
    """Reads audio data and returns (PCM audio data, sample rate).
    Assumes the input is a tuple (sample_rate, numpy_array).
    If the sample rate is unsupported, resamples to 16000 Hz.
    """
    sample_rate, data = audio_data

    # Ensure data is in the correct shape
    assert len(data.shape) == 1, "Audio data must be a 1D array"

    # Convert to floating point if necessary
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max

    # Supported sample rates
    supported_sample_rates = (8000, 16000, 32000, 48000)

    # If sample rate is not supported, resample to 16000 Hz
    if sample_rate not in supported_sample_rates:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    # Convert numpy array to PCM format
    pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()

    return data, pcm_data


def VAD_chunk(aggressiveness, data):
    audio, byte_audio = read_wave(data)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
            j = end_j
        else:
            speech_times.append((j, end))
            speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
    return speech_times, speech_segs


def get_embedding(data, embedder_net, device, n_threshold=-1):
    times, segs = VAD_chunk(0, data)
    if not segs:
        print(f'No voice activity detected')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print(f'No concatenated segments')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        #print(f'No STFT frames')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)

    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
        embeddings = embeddings[:n_threshold, :]

    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding


model_path = "./speech_id_checkpoint/saved_02.model"


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

embedder_net = SpeechEmbedder().to(device)
embedder_net.load_state_dict(torch.load(model_path, map_location=device))
embedder_net.eval()

def process_audio(audio1, audio2, threshold):
    e1 = get_embedding(audio1, embedder_net, device)
    if(e1 is None):
        return "No Voice Detected in file 1"
    e2 = get_embedding(audio2, embedder_net, device)
    if(e2 is None):
        return "No Voice Detected in file 2"

    cosi = cosine_similarity(e1, e2)

    if(cosi > threshold):
        return f"Same Speaker"
    else:
        return f"Different Speaker"

# Define the Gradio interface
def gradio_interface(audio1, audio2, threshold):
    output_text = process_audio(audio1, audio2, threshold)
    return output_text

# Create the Gradio interface with microphone inputs
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio("microphone", type="numpy", label="Audio File 1"),
            gr.Audio("microphone", type="numpy", label="Audio File 2"),
            gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Threshold")
            ],
    outputs="text",
    title="Gujarati Text Independent Speaker Verification",
    description="Record two audio files and get the text output from the model."
)

# Launch the interface
iface.launch(share=False)
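For reference, a minimal offline check of process_audio() as defined above. This is only a sketch: it assumes the definitions from app.py are already in the current session (for example, pasted into a notebook cell before iface.launch() runs), and the two WAV paths are placeholders for real recordings. gr.Audio(type="numpy") hands the app exactly this (sample_rate, samples) tuple shape.

# Hypothetical offline check; assumes read_wave/get_embedding/process_audio and
# embedder_net/device from app.py are already defined in this session.
import librosa
import numpy as np

def as_gradio_tuple(path):
    # Mimic what gr.Audio(type="numpy") delivers: (sample_rate, 1-D samples).
    samples, sr = librosa.load(path, sr=16000, mono=True)
    return (sr, samples.astype(np.float32))

clip1 = as_gradio_tuple("speaker_a_take1.wav")  # placeholder paths
clip2 = as_gradio_tuple("speaker_a_take2.wav")
print(process_audio(clip1, clip2, threshold=0.85))  # "Same Speaker" / "Different Speaker" / "No Voice Detected ..."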
config/config.yaml
ADDED
@@ -0,0 +1,40 @@
training: !!bool "false"
device: "mps"
unprocessed_data: './DATA_DIR/*/*.wav'
---
data:
    train_path: './train_tisv'
    train_path_unprocessed: './TIMIT/TRAIN/*/*/*.wav'
    test_path: './test_tisv'
    test_path_unprocessed: './TIMIT/TEST/*/*/*.wav'
    data_preprocessed: !!bool "true"
    sr: 16000
    nfft: 512 #For mel spectrogram preprocess
    window: 0.025 #(s)
    hop: 0.01 #(s)
    nmels: 40 #Number of mel energies
    tisv_frame: 180 #Max number of time steps in input after preprocess
---
model:
    hidden: 768 #Number of LSTM hidden layer units
    num_layer: 3 #Number of LSTM layers
    proj: 256 #Embedding size
    model_path: './speech_id_checkpoint/ckpt_epoch_840_batch_id_6.pth' #Model path for testing, inference, or resuming training
---
train:
    N : 4 #Number of speakers in batch
    M : 6 #Number of utterances per speaker
    num_workers: 0 #number of workers for dataloader
    lr: 0.01
    epochs: 1000 #Max training speaker epoch
    log_interval: 30 #Epochs before printing progress
    log_file: './speech_id_checkpoint/Stats'
    checkpoint_interval: 100 #Save model after x speaker epochs
    checkpoint_dir: './speech_id_checkpoint'
    restore: !!bool "true" #Resume training from previous model path
---
test:
    N : 4 #Number of speakers in batch
    M : 6 #Number of utterances per speaker
    num_workers: 8 #number of workers for data loader
    epochs: 10 #testing speaker epochs
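The config above is a multi-document YAML stream: each "---" starts a new document, and utils/hparam.py merges all documents into one dictionary. A minimal sketch of that parsing, assuming it is run from the repository root:

# Minimal sketch of how the multi-document config is consumed (mirrors utils/hparam.py):
# every YAML document separated by "---" is merged into a single dict.
import yaml

with open("config/config.yaml") as stream:
    merged = {}
    for doc in yaml.load_all(stream, Loader=yaml.Loader):
        merged.update(doc)

print(merged["data"]["sr"])                         # 16000
print(merged["model"]["hidden"], merged["model"]["proj"])  # 768 256
print(merged["train"]["N"], merged["train"]["M"])   # 4 6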
gradio.ipynb
ADDED
@@ -0,0 +1,292 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "23237138-936a-44b4-9eb6-f16045d2c91d",
   "metadata": {},
   "source": [
    "### **Gradio Demo | LSTM Speaker Embedding Model for Gujarati Speaker Verification**\n",
    "****\n",
    "**Author:** Irsh Vijay <br>\n",
    "**Organization**: Wadhwani Institute for Artificial Intelligence <br>\n",
    "****\n",
    "This notebook has the required code to run a gradio demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1d2cfd8b-9498-4236-9d32-718e9e0597cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import librosa\n",
    "import numpy as np\n",
    "import os\n",
    "import webrtcvad\n",
    "import wave\n",
    "import contextlib\n",
    "\n",
    "from utils.VAD_segments import *\n",
    "from utils.hparam import hparam as hp\n",
    "from utils.speech_embedder_net import *\n",
    "from utils.evaluation import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3e9e1006-83d2-4492-a210-26b2c3717cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_wave(audio_data):\n",
    "    \"\"\"Reads audio data and returns (PCM audio data, sample rate).\n",
    "    Assumes the input is a tuple (sample_rate, numpy_array).\n",
    "    If the sample rate is unsupported, resamples to 16000 Hz.\n",
    "    \"\"\"\n",
    "    sample_rate, data = audio_data\n",
    "\n",
    "    # Ensure data is in the correct shape\n",
    "    assert len(data.shape) == 1, \"Audio data must be a 1D array\"\n",
    "\n",
    "    # Convert to floating point if necessary\n",
    "    if not np.issubdtype(data.dtype, np.floating):\n",
    "        data = data.astype(np.float32) / np.iinfo(data.dtype).max\n",
    "\n",
    "    # Supported sample rates\n",
    "    supported_sample_rates = (8000, 16000, 32000, 48000)\n",
    "\n",
    "    # If sample rate is not supported, resample to 16000 Hz\n",
    "    if sample_rate not in supported_sample_rates:\n",
    "        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)\n",
    "        sample_rate = 16000\n",
    "\n",
    "    # Convert numpy array to PCM format\n",
    "    pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()\n",
    "\n",
    "    return data, pcm_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b56a2fc-83c3-4b36-95b8-5f1b656150ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def VAD_chunk(aggressiveness, data):\n",
    "    audio, byte_audio = read_wave(data)\n",
    "    vad = webrtcvad.Vad(int(aggressiveness))\n",
    "    frames = frame_generator(20, byte_audio, hp.data.sr)\n",
    "    frames = list(frames)\n",
    "    times = vad_collector(hp.data.sr, 20, 200, vad, frames)\n",
    "    speech_times = []\n",
    "    speech_segs = []\n",
    "    for i, time in enumerate(times):\n",
    "        start = np.round(time[0],decimals=2)\n",
    "        end = np.round(time[1],decimals=2)\n",
    "        j = start\n",
    "        while j + .4 < end:\n",
    "            end_j = np.round(j+.4,decimals=2)\n",
    "            speech_times.append((j, end_j))\n",
    "            speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])\n",
    "            j = end_j\n",
    "        else:\n",
    "            speech_times.append((j, end))\n",
    "            speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])\n",
    "    return speech_times, speech_segs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "72f257cf-7d3f-4ec5-944a-57779ba377e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding(data, embedder_net, device, n_threshold=-1):\n",
    "    times, segs = VAD_chunk(0, data)\n",
    "    if not segs:\n",
    "        print(f'No voice activity detected')\n",
    "        return None\n",
    "    concat_seg = concat_segs(times, segs)\n",
    "    if not concat_seg:\n",
    "        print(f'No concatenated segments')\n",
    "        return None\n",
    "    STFT_frames = get_STFTs(concat_seg)\n",
    "    if not STFT_frames:\n",
    "        #print(f'No STFT frames')\n",
    "        return None\n",
    "    STFT_frames = np.stack(STFT_frames, axis=2)\n",
    "    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        embeddings = embedder_net(STFT_frames)\n",
    "        embeddings = embeddings[:n_threshold, :]\n",
    "\n",
    "    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()\n",
    "    return avg_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "200df766-407d-4367-b0fb-7a6118653731",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"./speech_id_checkpoint/saved_01.model\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "db7613e6-67a8-4920-a999-caca4a0de360",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpeechEmbedder(\n",
       "  (LSTM_stack): LSTM(40, 768, num_layers=3, batch_first=True)\n",
       "  (projection): Linear(in_features=768, out_features=256, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n",
    "\n",
    "embedder_net = SpeechEmbedder().to(device)\n",
    "embedder_net.load_state_dict(torch.load(model_path, map_location=device))\n",
    "embedder_net.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8a7dd9bd-7b40-41f9-8e2f-d68be18f2111",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "bd6c073d-eab8-4ae6-8ba6-d90a0ec54c0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7868\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7868/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def process_audio(audio1, audio2, threshold):\n",
    "    e1 = get_embedding(audio1, embedder_net, device)\n",
    "    if(e1 is None):\n",
    "        return \"No Voice Detected in file 1\"\n",
    "    e2 = get_embedding(audio2, embedder_net, device)\n",
    "    if(e2 is None):\n",
    "        return \"No Voice Detected in file 2\"\n",
    "\n",
    "    cosi = cosine_similarity(e1, e2)\n",
    "\n",
    "    if(cosi > threshold):\n",
    "        return f\"Same Speaker\" \n",
    "    else:\n",
    "        return f\"Different Speaker\" \n",
    "\n",
    "# Define the Gradio interface\n",
    "def gradio_interface(audio1, audio2, threshold):\n",
    "    output_text = process_audio(audio1, audio2, threshold)\n",
    "    return output_text\n",
    "\n",
    "# Create the Gradio interface with microphone inputs\n",
    "iface = gr.Interface(\n",
    "    fn=gradio_interface,\n",
    "    inputs=[gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 1\"),\n",
    "            gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 2\"),\n",
    "            gr.Slider(0.0, 1.0, value=0.85, step=0.01, label=\"Threshold\")\n",
    "           ],\n",
    "    outputs=\"text\",\n",
    "    title=\"LSTM Based Speaker Verification\",\n",
    "    description=\"Record two audio files and get the text output from the model.\"\n",
    ")\n",
    "\n",
    "# Launch the interface\n",
    "iface.launch(share=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a098495c-9e7b-4232-86fc-55a1890c5e27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99a253e-9b91-4210-b934-8bd1b6a2d912",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
librosa
numpy
webrtcvad
wave
contextlib
gradio
PyYAML
speech_id_checkpoint/saved_02.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51b96ce4d80a01ebe039ed6bc67c1a9731315742d5814fed842d4a22785c5836
size 48543874
utils/.ipynb_checkpoints/VAD_segments-checkpoint.py
ADDED
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 18 16:22:41 2018

@author: Harry
Modified from https://github.com/wiseman/py-webrtcvad/blob/master/example.py
"""

import collections
import contextlib
import numpy as np
import sys
import librosa
import wave

import webrtcvad

from utils.hparam import hparam as hp

def read_wave(path, sr):
    """Reads a .wav file.
    Takes the path, and returns (PCM audio data, sample rate).
    Assumes sample width == 2
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data

class Frame(object):
    """Represents a "frame" of audio data."""
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.
    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.
    Given a webrtcvad.Vad and a source of audio frames, yields only
    the voiced audio.
    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers and begins yielding
    audio frames. Then the collector waits until 90% of the frames in
    the window are unvoiced to detrigger.
    The window is padded at the front and back to provide a small
    amount of silence or the beginnings/endings of speech around the
    voiced frames.
    Arguments:
    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).
    Returns: A generator that yields PCM audio data.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                start = ring_buffer[0][0].timestamp
                # We want to yield all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and yield whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
                voiced_frames = []
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield (start, frame.timestamp + frame.duration)


def VAD_chunk(aggressiveness, path):
    audio, byte_audio = read_wave(path, sr=hp.data.sr)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
            j = end_j
        else:
            speech_times.append((j, end))
            speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
    return speech_times, speech_segs

if __name__ == '__main__':
    speech_times, speech_segs = VAD_chunk(sys.argv[1], sys.argv[2])
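An illustrative run of the VAD utilities above on a single recording. A sketch only: it assumes a mono, 16-bit WAV at one of the supported sample rates sits at the placeholder path below, and that config/config.yaml is readable so hp.data.sr resolves.

# Hypothetical usage of VAD_chunk; the .wav path is a placeholder.
from utils.VAD_segments import VAD_chunk

times, segs = VAD_chunk(2, "sample_gujarati_utterance.wav")  # aggressiveness 0-3
for (start, end), seg in zip(times, segs):
    print(f"voiced {start:.2f}s-{end:.2f}s ({len(seg)} samples)")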
utils/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
File without changes
utils/.ipynb_checkpoints/data_load-checkpoint.py
ADDED
@@ -0,0 +1,57 @@
"""
Mostly copied from https://github.com/HarryVolek/PyTorch_Speaker_Verification
"""
import glob
import numpy as np
import os
import random
from random import shuffle
import torch
from torch.utils.data import Dataset

from utils.hparam import hparam as hp
from utils.utils import mfccs_and_spec

class GujaratiSpeakerVerificationDataset(Dataset):

    def __init__(self, shuffle=True, utter_start=0, split='train'):
        # data path
        if split!='val':
            self.path = hp.data.train_path
            self.utter_num = hp.train.M
        else:
            self.path = hp.data.test_path
            self.utter_num = hp.test.M
        self.file_list = os.listdir(self.path)
        self.shuffle=shuffle
        self.utter_start = utter_start
        self.split = split

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):

        np_file_list = os.listdir(self.path)

        if self.shuffle:
            selected_file = random.sample(np_file_list, 1)[0]  # select random speaker
        else:
            selected_file = np_file_list[idx]

        utters = np.load(os.path.join(self.path, selected_file))

        # load utterance spectrogram of selected speaker
        if self.shuffle:
            utter_index = np.random.randint(0, utters.shape[0], self.utter_num)  # select M utterances per speaker
            utterance = utters[utter_index]
        else:
            utterance = utters[self.utter_start: self.utter_start+self.utter_num]  # utterances of a speaker [batch(M), n_mels, frames]

        utterance = utterance[:,:,:160]  # TODO implement variable length batch size

        utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1)))  # transpose [batch, frames, n_mels]
        return utterance

    def __repr__(self):
        return f"{self.__class__.__name__}(split={self.split!r}, num_speakers={len(self.file_list)}, num_utterances={self.utter_num})"
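A sketch of how the dataset above would typically be batched for GE2E-style training. It assumes hp.data.train_path already contains one preprocessed .npy file per speaker with shape (num_utterances, n_mels, frames); the DataLoader settings are illustrative, not the repository's training script.

from torch.utils.data import DataLoader
from utils.data_load import GujaratiSpeakerVerificationDataset
from utils.hparam import hparam as hp

train_set = GujaratiSpeakerVerificationDataset(split='train')
train_loader = DataLoader(train_set, batch_size=hp.train.N,  # N speakers per batch
                          shuffle=True, drop_last=True)

batch = next(iter(train_loader))
print(batch.shape)  # (N, M, 160, n_mels), e.g. torch.Size([4, 6, 160, 40])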
utils/.ipynb_checkpoints/evaluation-checkpoint.py
ADDED
@@ -0,0 +1,192 @@
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import os
import librosa
import numpy as np
import torch
import random
from numpy.linalg import norm

from utils.VAD_segments import VAD_chunk
from utils.hparam import hparam as hp

class GujaratiSpeakerVerificationDatasetTest(Dataset):
    def __init__(self, path, shuffle=True, utter_start=0):
        # data path
        self.path = path
        self.file_list = os.listdir(self.path)
        self.shuffle=shuffle
        self.utter_start = utter_start
        self.utter_num = 4

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):

        np_file_list = self.file_list

        selected_file = np_file_list[idx]

        utters = np.load(os.path.join(self.path, selected_file))

        # load utterance spectrogram of selected speaker
        if self.shuffle:
            utter_index = np.random.randint(0, utters.shape[0], self.utter_num)  # select M utterances per speaker
            utterance = utters[utter_index]
        else:
            utterance = utters[self.utter_start: self.utter_start+self.utter_num]  # utterances of a speaker [batch(M), n_mels, frames]

        utterance = utterance[:,:,:160]  # TODO implement variable length batch size

        utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1)))  # transpose [batch, frames, n_mels]
        return utterance

def concat_segs(times, segs):
    concat_seg = []
    seg_concat = segs[0]
    for i in range(0, len(times)-1):
        if times[i][1] == times[i+1][0]:
            seg_concat = np.concatenate((seg_concat, segs[i+1]))
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i+1]
    else:
        concat_seg.append(seg_concat)
    return concat_seg


def get_STFTs(segs):
    sr = 16000
    STFT_frames = []
    for seg in segs:
        S = librosa.core.stft(y=seg, n_fft=hp.data.nfft,
                              win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
        S = np.abs(S)**2
        mel_basis = librosa.filters.mel(sr=sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
        S = np.log10(np.dot(mel_basis, S) + 1e-6)
        for j in range(0, S.shape[1], int(.12/hp.data.hop)):
            if j + 24 < S.shape[1]:
                STFT_frames.append(S[:, j:j+24])
            else:
                break
    return STFT_frames


def get_embedding(file_path, embedder_net, device, n_threshold=-1):
    times, segs = VAD_chunk(2, file_path)
    if not segs:
        print(f'No voice activity detected in {file_path}')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print(f'No concatenated segments for {file_path}')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        #print(f'No STFT frames for {file_path}')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)

    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
        embeddings = embeddings[:n_threshold, :]

    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding

def get_speaker_embeddings_listdir(embedder_net, device, list_dir, k):
    speaker_embeddings = {}
    for speaker_name in tqdm(list_dir, leave = False):
        speaker_dir = speaker_name
        if os.path.isdir(speaker_dir) and speaker_dir[0] != ".DS_Store":
            speaker_embeddings[speaker_name] = []
            for i in range(10):
                embeddings = []
                audio_files = [os.path.join(speaker_dir, f) for f in os.listdir(speaker_dir) if f.endswith('.wav')]
                random.shuffle(audio_files)
                count = 0
                iter_ = 0
                while(count <= k):
                    file_path = audio_files[iter_]
                    embedding = get_embedding(file_path, embedder_net, device)
                    try:
                        _ = embedding.shape
                        embeddings.append(embedding)
                        count+=1
                        iter_+=1
                    except:
                        iter_+=1
                speaker_embeddings[speaker_name].append(np.mean(embeddings, axis=0))
    return speaker_embeddings

def create_pairs(speaker_embeddings):
    pairs = []
    labels = []
    speakers = list(speaker_embeddings.keys())

    for i in range(len(speakers)):
        for j in range(len(speakers)):
            for k1 in range(10):
                for k2 in range(10):
                    emb1 = speaker_embeddings[speakers[i]][k1]
                    emb2 = speaker_embeddings[speakers[j]][k2]
                    pairs.append((emb1, emb2))
                    if i == j and not((emb1 == emb2).all()):
                        labels.append(1)  # Same speaker
                    else:
                        labels.append(0)  # Different speakers
    return pairs, labels

class EmbeddingPairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        emb1, emb2 = self.pairs[idx]
        label = self.labels[idx]

        emb1, emb2 = torch.tensor(emb1, dtype=torch.float32), torch.tensor(emb2, dtype=torch.float32)

        concatenated = torch.cat((emb1, emb2), dim=1)

        return concatenated.squeeze(), torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __repr__(self):
        return f"{self.__class__.__name__}(length={self.__len__()})"


def cosine_similarity(A, B):
    A = A.flatten().astype(np.float64)
    B = B.flatten().astype(np.float64)
    cosine = np.dot(A,B)/(norm(A)*norm(B))
    return cosine


def create_subset(dataset, num_zeros):
    pairs = dataset.pairs
    labels = dataset.labels

    pairs_1 = [pairs[i] for i in range(len(pairs)) if labels[i] == 1]
    labels_1 = [labels[i] for i in range(len(labels)) if labels[i] == 1]

    pairs_0 = [pairs[i] for i in range(len(pairs)) if labels[i] == 0]
    labels_0 = [labels[i] for i in range(len(labels)) if labels[i] == 0]

    num_zeros = min(num_zeros, len(pairs_0))

    pairs_0 = pairs_0[:num_zeros]
    labels_0 = labels_0[:num_zeros]

    filtered_pairs = pairs_1 + pairs_0
    filtered_labels = labels_1 + labels_0

    return filtered_pairs, filtered_labels
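A hedged sketch of the file-based verification primitive defined above: embed two recordings with get_embedding() and threshold their cosine similarity. The checkpoint path comes from this commit (it is an LFS pointer, so the weights must be fetched first), and the two .wav paths are placeholders.

import torch
from utils.speech_embedder_net import SpeechEmbedder
from utils.evaluation import get_embedding, cosine_similarity

device = torch.device("cpu")
embedder_net = SpeechEmbedder().to(device)
embedder_net.load_state_dict(torch.load("./speech_id_checkpoint/saved_02.model", map_location=device))
embedder_net.eval()

emb_a = get_embedding("enroll_speaker.wav", embedder_net, device)  # placeholder paths
emb_b = get_embedding("test_clip.wav", embedder_net, device)
if emb_a is not None and emb_b is not None:
    score = cosine_similarity(emb_a, emb_b)  # scalar cosine similarity
    print("same speaker" if score > 0.85 else "different speaker", score)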
utils/.ipynb_checkpoints/hparam-checkpoint.py
ADDED
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python

import yaml

def load_hparam(filename):
    stream = open(filename, 'r')
    docs = yaml.load_all(stream, Loader=yaml.Loader)
    hparam_dict = dict()
    for doc in docs:
        for k, v in doc.items():
            hparam_dict[k] = v
    return hparam_dict

def merge_dict(user, default):
    if isinstance(user, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in user:
                user[k] = v
            else:
                user[k] = merge_dict(user[k], v)
    return user


class Dotdict(dict):
    """
    a dictionary that supports dot notation
    as well as dictionary access notation
    usage: d = DotDict() or d = DotDict({'val1':'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct=None):
        dct = dict() if not dct else dct
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = Dotdict(value)
            self[key] = value


class Hparam(Dotdict):

    def __init__(self, file='config/config.yaml'):
        super(Dotdict, self).__init__()
        hp_dict = load_hparam(file)
        hp_dotdict = Dotdict(hp_dict)
        for k, v in hp_dotdict.items():
            setattr(self, k, v)

    __getattr__ = Dotdict.__getitem__
    __setattr__ = Dotdict.__setitem__
    __delattr__ = Dotdict.__delitem__


hparam = Hparam()
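This is how the rest of the code reads configuration: the module-level Hparam() instance wraps each YAML document in a Dotdict, so nested keys are reachable with dot notation. A minimal sketch, assuming it is run from the repository root so 'config/config.yaml' resolves:

from utils.hparam import hparam as hp

print(hp.data.sr)                      # 16000
print(hp.data.nmels)                   # 40
print(hp.model.hidden, hp.model.proj)  # 768 256
print(hp.train.N, hp.train.M)          # 4 6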
utils/.ipynb_checkpoints/kan-checkpoint.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.view(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.view(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
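A quick shape check of the KAN components above, independent of this repository's training setup. The layer widths below are illustrative only; in the Space, KANLinear is used as a 768-to-256 projection head.

import torch
from utils.kan import KAN, KANLinear

layer = KANLinear(in_features=768, out_features=256)
x = torch.randn(32, 768)
print(layer(x).shape)             # torch.Size([32, 256])

net = KAN([40, 64, 16])           # hidden sizes chosen for illustration
y = net(torch.randn(8, 40))
print(y.shape)                    # torch.Size([8, 16])
print(net.regularization_loss())  # scalar, usable as an optional penalty term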
utils/.ipynb_checkpoints/speech_embedder_net-checkpoint.py
ADDED
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 5 20:58:34 2018

@author: harry
"""

import torch
import torch.nn as nn

from utils.hparam import hparam as hp
from utils.utils import get_centroids, get_cossim, calc_loss
from utils.kan import KANLinear

class SpeechEmbedder(nn.Module):

    def __init__(self):
        super(SpeechEmbedder, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  #(batch, frames, n_mels)
        #only use last frame
        x = x[:,x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x


class SpeechEmbedderGRU(nn.Module):
    def __init__(self):
        super(SpeechEmbedderGRU, self).__init__()
        self.GRU_stack = nn.GRU(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.GRU_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.GRU_stack(x.float())  #(batch, frames, n_mels)
        #only use last frame
        x = x[:,x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x

class SpeechEmbedderKAN(nn.Module):
    def __init__(self):
        super(SpeechEmbedderKAN, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = KANLinear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  #(batch, frames, n_mels)
        #only use last frame
        x = x[:,x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x



class SpeechEmbedderBidirectional(nn.Module):
    def __init__(self):
        super(SpeechEmbedderBidirectional, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True, bidirectional=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  #(batch, frames, n_mels)
        #only use last frame
        x = x[:, :, :hp.model.hidden]

        x = x[:,x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x

class GE2ELoss(nn.Module):

    def __init__(self, device):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
        self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
        self.device = device

    def forward(self, embeddings):
        torch.clamp(self.w, 1e-6)
        centroids = get_centroids(embeddings)
        cossim = get_cossim(embeddings, centroids)
        sim_matrix = self.w*cossim.to(self.device) + self.b
        loss, _ = calc_loss(sim_matrix)
        return loss
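A minimal sanity check of the embedding network above, using random features in the shape produced by get_STFTs(): (batch, frames, n_mels) = (batch, 24, 40). It assumes config/config.yaml is readable (the layer sizes come from hp) and uses randomly initialized weights, so the embeddings are not meaningful, only the shapes and normalization are.

import torch
from utils.speech_embedder_net import SpeechEmbedder

net = SpeechEmbedder()
dummy = torch.randn(5, 24, 40)   # 5 segments, 24 frames, 40 mel bands
with torch.no_grad():
    emb = net(dummy)
print(emb.shape)                 # torch.Size([5, 256]); one embedding per segment
print(torch.norm(emb, dim=1))    # ~1.0 for every row (L2-normalized)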
utils/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 20 16:56:19 2018

@author: harry
"""
import librosa
import numpy as np
import torch
import torch.autograd as grad
import torch.nn.functional as F

from utils.hparam import hparam as hp

def get_centroids_prior(embeddings):
    centroids = []
    for speaker in embeddings:
        centroid = 0
        for utterance in speaker:
            centroid = centroid + utterance
        centroid = centroid/len(speaker)
        centroids.append(centroid)
    centroids = torch.stack(centroids)
    return centroids

def get_centroids(embeddings):
    centroids = embeddings.mean(dim=1)
    return centroids

def get_centroid(embeddings, speaker_num, utterance_num):
    centroid = 0
    for utterance_id, utterance in enumerate(embeddings[speaker_num]):
        if utterance_id == utterance_num:
            continue
        centroid = centroid + utterance
    centroid = centroid/(len(embeddings[speaker_num])-1)
    return centroid

def get_utterance_centroids(embeddings):
    """
    Returns the centroids for each utterance of a speaker, where
    the utterance centroid is the speaker centroid without considering
    this utterance

    Shape of embeddings should be:
        (speaker_ct, utterance_per_speaker_ct, embedding_size)
    """
    sum_centroids = embeddings.sum(dim=1)
    # we want to subtract out each utterance, prior to calculating the
    # the utterance centroid
    sum_centroids = sum_centroids.reshape(
        sum_centroids.shape[0], 1, sum_centroids.shape[-1]
    )
    # we want the mean but not including the utterance itself, so -1
    num_utterances = embeddings.shape[1] - 1
    centroids = (sum_centroids - embeddings) / num_utterances
    return centroids

def get_cossim_prior(embeddings, centroids):
    # Calculates cosine similarity matrix. Requires (N, M, feature) input
    cossim = torch.zeros(embeddings.size(0), embeddings.size(1), centroids.size(0))
    for speaker_num, speaker in enumerate(embeddings):
        for utterance_num, utterance in enumerate(speaker):
            for centroid_num, centroid in enumerate(centroids):
                if speaker_num == centroid_num:
                    centroid = get_centroid(embeddings, speaker_num, utterance_num)
                output = F.cosine_similarity(utterance, centroid, dim=0)+1e-6
                cossim[speaker_num][utterance_num][centroid_num] = output
    return cossim

def get_cossim(embeddings, centroids):
    # number of utterances per speaker
    num_utterances = embeddings.shape[1]
    utterance_centroids = get_utterance_centroids(embeddings)

    # flatten the embeddings and utterance centroids to just utterance,
    # so we can do cosine similarity
    utterance_centroids_flat = utterance_centroids.view(
        utterance_centroids.shape[0] * utterance_centroids.shape[1],
        -1
    )
    embeddings_flat = embeddings.view(
        embeddings.shape[0] * num_utterances,
        -1
    )
    # the cosine distance between utterance and the associated centroids
    # for that utterance
    # this is each speaker's utterances against his own centroid, but each
    # comparison centroid has the current utterance removed
    cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)

    # now we get the cosine distance between each utterance and the other speakers'
    # centroids
    # to do so requires comparing each utterance to each centroid. To keep the
    # operation fast, we vectorize by using matrices L (embeddings) and
    # R (centroids) where L has each utterance repeated sequentially for all
    # comparisons and R has the entire centroids frame repeated for each utterance
    centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
    embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
    embeddings_expand = embeddings_expand.view(
        embeddings_expand.shape[0] * embeddings_expand.shape[1],
        embeddings_expand.shape[-1]
    )
    cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
    cos_diff = cos_diff.view(
        embeddings.size(0),
        num_utterances,
        centroids.size(0)
    )
    # assign the cosine distance for same speakers to the proper idx
    same_idx = list(range(embeddings.size(0)))
    cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
    cos_diff = cos_diff + 1e-6
    return cos_diff

def calc_loss_prior(sim_matrix):
    # Calculates loss from (N, M, K) similarity matrix
    per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1))
    for j in range(len(sim_matrix)):
        for i in range(sim_matrix.size(1)):
            per_embedding_loss[j][i] = -(sim_matrix[j][i][j] - ((torch.exp(sim_matrix[j][i]).sum()+1e-6).log_()))
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss

def calc_loss(sim_matrix):
    same_idx = list(range(sim_matrix.size(0)))
    pos = sim_matrix[same_idx, :, same_idx]
    neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
    per_embedding_loss = -1 * (pos - neg)
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss

def normalize_0_1(values, max_value, min_value):
    normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1)
    return normalized

def mfccs_and_spec(wav_file, wav_process=False, calc_mfccs=False, calc_mag_db=False):
    sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr)
    window_length = int(hp.data.window*hp.data.sr)
    hop_length = int(hp.data.hop*hp.data.sr)
    duration = hp.data.tisv_frame * hp.data.hop + hp.data.window

    # Cut silence and fix length
    if wav_process == True:
        sound_file, index = librosa.effects.trim(sound_file, frame_length=window_length, hop_length=hop_length)
        length = int(hp.data.sr * duration)
        sound_file = librosa.util.fix_length(sound_file, length)

    spec = librosa.stft(sound_file, n_fft=hp.data.nfft, hop_length=hop_length, win_length=window_length)
    mag_spec = np.abs(spec)

    mel_basis = librosa.filters.mel(hp.data.sr, hp.data.nfft, n_mels=hp.data.nmels)
    mel_spec = np.dot(mel_basis, mag_spec)

    mag_db = librosa.amplitude_to_db(mag_spec)
    # db mel spectrogram
    mel_db = librosa.amplitude_to_db(mel_spec).T

    mfccs = None
    if calc_mfccs:
        mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T

    return mfccs, mel_db, mag_db

if __name__ == "__main__":
    w = grad.Variable(torch.tensor(1.0))
    b = grad.Variable(torch.tensor(0.0))
    embeddings = torch.tensor([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]]).to(torch.float).reshape(3,2,3)
    centroids = get_centroids(embeddings)
    cossim = get_cossim(embeddings, centroids)
    sim_matrix = w*cossim + b
    loss, per_embedding_loss = calc_loss(sim_matrix)
utils/VAD_segments.py
ADDED
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 18 16:22:41 2018

@author: Harry
Modified from https://github.com/wiseman/py-webrtcvad/blob/master/example.py
"""

import collections
import contextlib
import numpy as np
import sys
import librosa
import wave

import webrtcvad

from utils.hparam import hparam as hp

def read_wave(path, sr):
    """Reads a .wav file.
    Takes the path, and returns (PCM audio data, sample rate).
    Assumes sample width == 2
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data

class Frame(object):
    """Represents a "frame" of audio data."""
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.
    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.
    Given a webrtcvad.Vad and a source of audio frames, yields only
    the voiced audio.
    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers and begins yielding
    audio frames. Then the collector waits until 90% of the frames in
    the window are unvoiced to detrigger.
    The window is padded at the front and back to provide a small
    amount of silence or the beginnings/endings of speech around the
    voiced frames.
    Arguments:
    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).
    Returns: A generator that yields PCM audio data.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                start = ring_buffer[0][0].timestamp
                # We want to yield all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and yield whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
                voiced_frames = []
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield (start, frame.timestamp + frame.duration)


def VAD_chunk(aggressiveness, path):
    audio, byte_audio = read_wave(path, sr=hp.data.sr)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j+.4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
            j = end_j
        else:
            speech_times.append((j, end))
            speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
    return speech_times, speech_segs

if __name__ == '__main__':
    speech_times, speech_segs = VAD_chunk(sys.argv[1], sys.argv[2])
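
A minimal usage sketch for VAD_chunk (the WAV path "sample.wav" is a placeholder, and hp.data.sr is assumed to be one of webrtcvad's supported rates, as configured in config/config.yaml):

# Sketch only: "sample.wav" is a hypothetical mono, 16-bit recording; aggressiveness ranges 0-3.
from utils.VAD_segments import VAD_chunk

speech_times, speech_segs = VAD_chunk(2, "sample.wav")
for (start, end), seg in zip(speech_times, speech_segs):
    print(f"voiced {start:.2f}s-{end:.2f}s -> {len(seg)} samples")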
utils/__init__.py
ADDED
File without changes
utils/__pycache__/VAD_segments.cpython-39.pyc
ADDED
Binary file (4.68 kB)
utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (162 Bytes)
utils/__pycache__/data_load.cpython-39.pyc
ADDED
Binary file (2.05 kB)
utils/__pycache__/evaluation.cpython-39.pyc
ADDED
Binary file (6.62 kB)
utils/__pycache__/hparam.cpython-39.pyc
ADDED
Binary file (1.98 kB)
utils/__pycache__/kan.cpython-39.pyc
ADDED
Binary file (7.57 kB)
utils/__pycache__/speech_embedder_net.cpython-39.pyc
ADDED
Binary file (4.45 kB)
utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (4.7 kB)
utils/data_load.py
ADDED
@@ -0,0 +1,57 @@
"""
Mostly copied from https://github.com/HarryVolek/PyTorch_Speaker_Verification
"""
import glob
import numpy as np
import os
import random
from random import shuffle
import torch
from torch.utils.data import Dataset

from utils.hparam import hparam as hp
from utils.utils import mfccs_and_spec

class GujaratiSpeakerVerificationDataset(Dataset):

    def __init__(self, shuffle=True, utter_start=0, split='train'):
        # data path
        if split != 'val':
            self.path = hp.data.train_path
            self.utter_num = hp.train.M
        else:
            self.path = hp.data.test_path
            self.utter_num = hp.test.M
        self.file_list = os.listdir(self.path)
        self.shuffle = shuffle
        self.utter_start = utter_start
        self.split = split

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):

        np_file_list = os.listdir(self.path)

        if self.shuffle:
            selected_file = random.sample(np_file_list, 1)[0]  # select random speaker
        else:
            selected_file = np_file_list[idx]

        utters = np.load(os.path.join(self.path, selected_file))

        # load utterance spectrogram of selected speaker
        if self.shuffle:
            utter_index = np.random.randint(0, utters.shape[0], self.utter_num)  # select M utterances per speaker
            utterance = utters[utter_index]
        else:
            utterance = utters[self.utter_start: self.utter_start+self.utter_num]  # utterances of a speaker [batch(M), n_mels, frames]

        utterance = utterance[:,:,:160]  # TODO implement variable length batch size

        utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1)))  # transpose [batch, frames, n_mels]
        return utterance

    def __repr__(self):
        return f"{self.__class__.__name__}(split={self.split!r}, num_speakers={len(self.file_list)}, num_utterances={self.utter_num})"
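
A short sketch of wrapping this dataset in a DataLoader for GE2E-style batches. The per-speaker .npy layout under hp.data.train_path and a speakers-per-batch key hp.train.N are assumptions about config/config.yaml, not shown here:

# Sketch: assumes hp.data.train_path holds one .npy per speaker with shape
# (num_utterances, n_mels, frames) and that hp.train.N is the speakers-per-batch count.
from torch.utils.data import DataLoader
from utils.data_load import GujaratiSpeakerVerificationDataset
from utils.hparam import hparam as hp

train_dataset = GujaratiSpeakerVerificationDataset(split='train')
train_loader = DataLoader(train_dataset, batch_size=hp.train.N, shuffle=True, drop_last=True)
mel_batch = next(iter(train_loader))  # (N speakers, M utterances, frames, n_mels)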
utils/evaluation.py
ADDED
@@ -0,0 +1,192 @@
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import os
import librosa
import numpy as np
import torch
import random
from numpy.linalg import norm

from utils.VAD_segments import VAD_chunk
from utils.hparam import hparam as hp

class GujaratiSpeakerVerificationDatasetTest(Dataset):
    def __init__(self, path, shuffle=True, utter_start=0):
        # data path
        self.path = path
        self.file_list = os.listdir(self.path)
        self.shuffle = shuffle
        self.utter_start = utter_start
        self.utter_num = 4

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):

        np_file_list = self.file_list

        selected_file = np_file_list[idx]

        utters = np.load(os.path.join(self.path, selected_file))

        # load utterance spectrogram of selected speaker
        if self.shuffle:
            utter_index = np.random.randint(0, utters.shape[0], self.utter_num)  # select M utterances per speaker
            utterance = utters[utter_index]
        else:
            utterance = utters[self.utter_start: self.utter_start+self.utter_num]  # utterances of a speaker [batch(M), n_mels, frames]

        utterance = utterance[:,:,:160]  # TODO implement variable length batch size

        utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1)))  # transpose [batch, frames, n_mels]
        return utterance

def concat_segs(times, segs):
    concat_seg = []
    seg_concat = segs[0]
    for i in range(0, len(times)-1):
        if times[i][1] == times[i+1][0]:
            seg_concat = np.concatenate((seg_concat, segs[i+1]))
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i+1]
    else:
        concat_seg.append(seg_concat)
    return concat_seg


def get_STFTs(segs):
    sr = 16000
    STFT_frames = []
    for seg in segs:
        S = librosa.core.stft(y=seg, n_fft=hp.data.nfft,
                              win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
        S = np.abs(S)**2
        mel_basis = librosa.filters.mel(sr=sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
        S = np.log10(np.dot(mel_basis, S) + 1e-6)
        for j in range(0, S.shape[1], int(.12/hp.data.hop)):
            if j + 24 < S.shape[1]:
                STFT_frames.append(S[:, j:j+24])
            else:
                break
    return STFT_frames


def get_embedding(file_path, embedder_net, device, n_threshold=-1):
    times, segs = VAD_chunk(2, file_path)
    if not segs:
        print(f'No voice activity detected in {file_path}')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print(f'No concatenated segments for {file_path}')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        # print(f'No STFT frames for {file_path}')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)

    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
        embeddings = embeddings[:n_threshold, :]

    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding

def get_speaker_embeddings_listdir(embedder_net, device, list_dir, k):
    speaker_embeddings = {}
    for speaker_name in tqdm(list_dir, leave=False):
        speaker_dir = speaker_name
        if os.path.isdir(speaker_dir) and os.path.basename(speaker_dir) != ".DS_Store":  # skip macOS metadata entries
            speaker_embeddings[speaker_name] = []
            for i in range(10):
                embeddings = []
                audio_files = [os.path.join(speaker_dir, f) for f in os.listdir(speaker_dir) if f.endswith('.wav')]
                random.shuffle(audio_files)
                count = 0
                iter_ = 0
                while(count <= k):
                    file_path = audio_files[iter_]
                    embedding = get_embedding(file_path, embedder_net, device)
                    try:
                        _ = embedding.shape
                        embeddings.append(embedding)
                        count += 1
                        iter_ += 1
                    except:
                        iter_ += 1
                speaker_embeddings[speaker_name].append(np.mean(embeddings, axis=0))
    return speaker_embeddings

def create_pairs(speaker_embeddings):
    pairs = []
    labels = []
    speakers = list(speaker_embeddings.keys())

    for i in range(len(speakers)):
        for j in range(len(speakers)):
            for k1 in range(10):
                for k2 in range(10):
                    emb1 = speaker_embeddings[speakers[i]][k1]
                    emb2 = speaker_embeddings[speakers[j]][k2]
                    pairs.append((emb1, emb2))
                    if i == j and not((emb1 == emb2).all()):
                        labels.append(1)  # Same speaker
                    else:
                        labels.append(0)  # Different speakers
    return pairs, labels

class EmbeddingPairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __getitem__(self, idx):
        emb1, emb2 = self.pairs[idx]
        label = self.labels[idx]

        emb1, emb2 = torch.tensor(emb1, dtype=torch.float32), torch.tensor(emb2, dtype=torch.float32)

        concatenated = torch.cat((emb1, emb2), dim=1)

        return concatenated.squeeze(), torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __repr__(self):
        return f"{self.__class__.__name__}(length={self.__len__()})"


def cosine_similarity(A, B):
    A = A.flatten().astype(np.float64)
    B = B.flatten().astype(np.float64)
    cosine = np.dot(A, B)/(norm(A)*norm(B))
    return cosine


def create_subset(dataset, num_zeros):
    pairs = dataset.pairs
    labels = dataset.labels

    pairs_1 = [pairs[i] for i in range(len(pairs)) if labels[i] == 1]
    labels_1 = [labels[i] for i in range(len(labels)) if labels[i] == 1]

    pairs_0 = [pairs[i] for i in range(len(pairs)) if labels[i] == 0]
    labels_0 = [labels[i] for i in range(len(labels)) if labels[i] == 0]

    num_zeros = min(num_zeros, len(pairs_0))

    pairs_0 = pairs_0[:num_zeros]
    labels_0 = labels_0[:num_zeros]

    filtered_pairs = pairs_1 + pairs_0
    filtered_labels = labels_1 + labels_0

    return filtered_pairs, filtered_labels
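
A sketch of a single verification trial built from these helpers. The two WAV paths are placeholders, and loading speech_id_checkpoint/saved_02.model into the LSTM SpeechEmbedder is an assumption for illustration:

# Sketch only: "spk_a.wav" / "spk_b.wav" are hypothetical recordings.
import torch
from utils.speech_embedder_net import SpeechEmbedder
from utils.evaluation import get_embedding, cosine_similarity

device = torch.device("cpu")
embedder_net = SpeechEmbedder().to(device)
embedder_net.load_state_dict(torch.load("speech_id_checkpoint/saved_02.model", map_location=device))
embedder_net.eval()

emb_a = get_embedding("spk_a.wav", embedder_net, device)
emb_b = get_embedding("spk_b.wav", embedder_net, device)
if emb_a is not None and emb_b is not None:
    print("cosine similarity:", cosine_similarity(emb_a, emb_b))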
utils/hparam.py
ADDED
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python

import yaml

def load_hparam(filename):
    stream = open(filename, 'r')
    docs = yaml.load_all(stream, Loader=yaml.Loader)
    hparam_dict = dict()
    for doc in docs:
        for k, v in doc.items():
            hparam_dict[k] = v
    return hparam_dict

def merge_dict(user, default):
    if isinstance(user, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in user:
                user[k] = v
            else:
                user[k] = merge_dict(user[k], v)
    return user


class Dotdict(dict):
    """
    a dictionary that supports dot notation
    as well as dictionary access notation
    usage: d = DotDict() or d = DotDict({'val1':'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct=None):
        dct = dict() if not dct else dct
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = Dotdict(value)
            self[key] = value


class Hparam(Dotdict):

    def __init__(self, file='config/config.yaml'):
        super(Dotdict, self).__init__()
        hp_dict = load_hparam(file)
        hp_dotdict = Dotdict(hp_dict)
        for k, v in hp_dotdict.items():
            setattr(self, k, v)

    __getattr__ = Dotdict.__getitem__
    __setattr__ = Dotdict.__setitem__
    __delattr__ = Dotdict.__delitem__


hparam = Hparam()
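
Since hparam is instantiated at import time, the YAML keys in config/config.yaml become nested attributes. A small sketch (the exact keys depend on that config file; the ones below are the ones referenced by the other utils modules):

# Sketch: attribute names mirror the YAML structure loaded from config/config.yaml.
from utils.hparam import hparam as hp

print(hp.data.sr, hp.data.nmels)       # audio settings shared across the utils modules
print(hp.model.hidden, hp.model.proj)  # LSTM width and embedding size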
utils/kan.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.view(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.view(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want a memory-efficient implementation.

        The L1 regularization is now computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
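
A standalone sketch of the KAN layers on random tensors (the shapes below are illustrative only):

# Sketch: exercises KANLinear and the stacked KAN wrapper on dummy data.
import torch
from utils.kan import KAN, KANLinear

x = torch.randn(8, 64)           # (batch, in_features)
layer = KANLinear(64, 32)
print(layer(x).shape)            # torch.Size([8, 32])

net = KAN([64, 128, 10])         # a stack of KANLinear layers
logits = net(x)
reg = net.regularization_loss()  # optional penalty to add to the task loss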
utils/speech_embedder_net.py
ADDED
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 5 20:58:34 2018

@author: harry
"""

import torch
import torch.nn as nn

from utils.hparam import hparam as hp
from utils.utils import get_centroids, get_cossim, calc_loss
from utils.kan import KANLinear

class SpeechEmbedder(nn.Module):

    def __init__(self):
        super(SpeechEmbedder, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  # (batch, frames, n_mels)
        # only use last frame
        x = x[:, x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x


class SpeechEmbedderGRU(nn.Module):
    def __init__(self):
        super(SpeechEmbedderGRU, self).__init__()
        self.GRU_stack = nn.GRU(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.GRU_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.GRU_stack(x.float())  # (batch, frames, n_mels)
        # only use last frame
        x = x[:, x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x


class SpeechEmbedderKAN(nn.Module):
    def __init__(self):
        super(SpeechEmbedderKAN, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = KANLinear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  # (batch, frames, n_mels)
        # only use last frame
        x = x[:, x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x


class SpeechEmbedderBidirectional(nn.Module):
    def __init__(self):
        super(SpeechEmbedderBidirectional, self).__init__()
        self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True, bidirectional=True)
        for name, param in self.LSTM_stack.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        self.projection = nn.Linear(hp.model.hidden, hp.model.proj)

    def forward(self, x):
        x, _ = self.LSTM_stack(x.float())  # (batch, frames, n_mels)
        # only use last frame
        x = x[:, :, :hp.model.hidden]
        x = x[:, x.size(1)-1]
        x = self.projection(x.float())
        x = x / torch.norm(x, dim=1).unsqueeze(1)
        return x


class GE2ELoss(nn.Module):

    def __init__(self, device):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
        self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
        self.device = device

    def forward(self, embeddings):
        torch.clamp(self.w, 1e-6)
        centroids = get_centroids(embeddings)
        cossim = get_cossim(embeddings, centroids)
        sim_matrix = self.w * cossim.to(self.device) + self.b
        loss, _ = calc_loss(sim_matrix)
        return loss
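
A sketch of one GE2E training step with these modules. The values N speakers, M utterances per speaker, and 160 frames are illustrative only, chosen to mirror the (N, M, ...) layout that GE2ELoss expects:

# Sketch: dummy mel batch, one forward/backward pass through SpeechEmbedder + GE2ELoss.
import torch
from utils.hparam import hparam as hp
from utils.speech_embedder_net import SpeechEmbedder, GE2ELoss

device = torch.device("cpu")
N, M, frames = 4, 5, 160
embedder_net = SpeechEmbedder().to(device)
ge2e_loss = GE2ELoss(device)

mels = torch.randn(N * M, frames, hp.data.nmels, device=device)
embeddings = embedder_net(mels).view(N, M, -1)  # (speakers, utterances, proj)
loss = ge2e_loss(embeddings)
loss.backward()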
utils/utils.py
ADDED
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 20 16:56:19 2018

@author: harry
"""
import librosa
import numpy as np
import torch
import torch.autograd as grad
import torch.nn.functional as F

from utils.hparam import hparam as hp

def get_centroids_prior(embeddings):
    centroids = []
    for speaker in embeddings:
        centroid = 0
        for utterance in speaker:
            centroid = centroid + utterance
        centroid = centroid/len(speaker)
        centroids.append(centroid)
    centroids = torch.stack(centroids)
    return centroids

def get_centroids(embeddings):
    centroids = embeddings.mean(dim=1)
    return centroids

def get_centroid(embeddings, speaker_num, utterance_num):
    centroid = 0
    for utterance_id, utterance in enumerate(embeddings[speaker_num]):
        if utterance_id == utterance_num:
            continue
        centroid = centroid + utterance
    centroid = centroid/(len(embeddings[speaker_num])-1)
    return centroid

def get_utterance_centroids(embeddings):
    """
    Returns the centroids for each utterance of a speaker, where
    the utterance centroid is the speaker centroid without considering
    this utterance

    Shape of embeddings should be:
        (speaker_ct, utterance_per_speaker_ct, embedding_size)
    """
    sum_centroids = embeddings.sum(dim=1)
    # we want to subtract out each utterance, prior to calculating the
    # the utterance centroid
    sum_centroids = sum_centroids.reshape(
        sum_centroids.shape[0], 1, sum_centroids.shape[-1]
    )
    # we want the mean but not including the utterance itself, so -1
    num_utterances = embeddings.shape[1] - 1
    centroids = (sum_centroids - embeddings) / num_utterances
    return centroids

def get_cossim_prior(embeddings, centroids):
    # Calculates cosine similarity matrix. Requires (N, M, feature) input
    cossim = torch.zeros(embeddings.size(0), embeddings.size(1), centroids.size(0))
    for speaker_num, speaker in enumerate(embeddings):
        for utterance_num, utterance in enumerate(speaker):
            for centroid_num, centroid in enumerate(centroids):
                if speaker_num == centroid_num:
                    centroid = get_centroid(embeddings, speaker_num, utterance_num)
                output = F.cosine_similarity(utterance, centroid, dim=0)+1e-6
                cossim[speaker_num][utterance_num][centroid_num] = output
    return cossim

def get_cossim(embeddings, centroids):
    # number of utterances per speaker
    num_utterances = embeddings.shape[1]
    utterance_centroids = get_utterance_centroids(embeddings)

    # flatten the embeddings and utterance centroids to just utterance,
    # so we can do cosine similarity
    utterance_centroids_flat = utterance_centroids.view(
        utterance_centroids.shape[0] * utterance_centroids.shape[1],
        -1
    )
    embeddings_flat = embeddings.view(
        embeddings.shape[0] * num_utterances,
        -1
    )
    # the cosine distance between utterance and the associated centroids
    # for that utterance
    # this is each speaker's utterances against his own centroid, but each
    # comparison centroid has the current utterance removed
    cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)

    # now we get the cosine distance between each utterance and the other speakers'
    # centroids
    # to do so requires comparing each utterance to each centroid. To keep the
    # operation fast, we vectorize by using matrices L (embeddings) and
    # R (centroids) where L has each utterance repeated sequentially for all
    # comparisons and R has the entire centroids frame repeated for each utterance
    centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
    embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
    embeddings_expand = embeddings_expand.view(
        embeddings_expand.shape[0] * embeddings_expand.shape[1],
        embeddings_expand.shape[-1]
    )
    cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
    cos_diff = cos_diff.view(
        embeddings.size(0),
        num_utterances,
        centroids.size(0)
    )
    # assign the cosine distance for same speakers to the proper idx
    same_idx = list(range(embeddings.size(0)))
    cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
    cos_diff = cos_diff + 1e-6
    return cos_diff

def calc_loss_prior(sim_matrix):
    # Calculates loss from (N, M, K) similarity matrix
    per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1))
    for j in range(len(sim_matrix)):
        for i in range(sim_matrix.size(1)):
            per_embedding_loss[j][i] = -(sim_matrix[j][i][j] - ((torch.exp(sim_matrix[j][i]).sum()+1e-6).log_()))
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss

def calc_loss(sim_matrix):
    same_idx = list(range(sim_matrix.size(0)))
    pos = sim_matrix[same_idx, :, same_idx]
    neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
    per_embedding_loss = -1 * (pos - neg)
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss

def normalize_0_1(values, max_value, min_value):
    normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1)
    return normalized

def mfccs_and_spec(wav_file, wav_process=False, calc_mfccs=False, calc_mag_db=False):
    sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr)
    window_length = int(hp.data.window*hp.data.sr)
    hop_length = int(hp.data.hop*hp.data.sr)
    duration = hp.data.tisv_frame * hp.data.hop + hp.data.window

    # Cut silence and fix length
    if wav_process == True:
        sound_file, index = librosa.effects.trim(sound_file, frame_length=window_length, hop_length=hop_length)
        length = int(hp.data.sr * duration)
        sound_file = librosa.util.fix_length(sound_file, length)

    spec = librosa.stft(sound_file, n_fft=hp.data.nfft, hop_length=hop_length, win_length=window_length)
    mag_spec = np.abs(spec)

    mel_basis = librosa.filters.mel(hp.data.sr, hp.data.nfft, n_mels=hp.data.nmels)
    mel_spec = np.dot(mel_basis, mag_spec)

    mag_db = librosa.amplitude_to_db(mag_spec)
    # db mel spectrogram
    mel_db = librosa.amplitude_to_db(mel_spec).T

    mfccs = None
    if calc_mfccs:
        mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T

    return mfccs, mel_db, mag_db

if __name__ == "__main__":
    w = grad.Variable(torch.tensor(1.0))
    b = grad.Variable(torch.tensor(0.0))
    embeddings = torch.tensor([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]]).to(torch.float).reshape(3,2,3)
    centroids = get_centroids(embeddings)
    cossim = get_cossim(embeddings, centroids)
    sim_matrix = w*cossim + b
    loss, per_embedding_loss = calc_loss(sim_matrix)