{
"cells": [
{
"cell_type": "markdown",
"id": "23237138-936a-44b4-9eb6-f16045d2c91d",
"metadata": {},
"source": [
"### **Gradio Demo | LSTM Speaker Embedding Model for Gujarati Speaker Verification**\n",
"****\n",
"**Author:** Irsh Vijay
\n",
"**Organization**: Wadhwani Institute for Artificial Intelligence
\n",
"****\n",
"This notebook has the required code to run a gradio demo."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1d2cfd8b-9498-4236-9d32-718e9e0597cb",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import librosa\n",
"import numpy as np\n",
"import os\n",
"import webrtcvad\n",
"import wave\n",
"import contextlib\n",
"\n",
"from utils.VAD_segments import *\n",
"from utils.hparam import hparam as hp\n",
"from utils.speech_embedder_net import *\n",
"from utils.evaluation import *"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3e9e1006-83d2-4492-a210-26b2c3717cd5",
"metadata": {},
"outputs": [],
"source": [
"def read_wave(audio_data):\n",
" \"\"\"Reads audio data and returns (PCM audio data, sample rate).\n",
" Assumes the input is a tuple (sample_rate, numpy_array).\n",
" If the sample rate is unsupported, resamples to 16000 Hz.\n",
" \"\"\"\n",
" sample_rate, data = audio_data\n",
"\n",
" # Ensure data is in the correct shape\n",
" assert len(data.shape) == 1, \"Audio data must be a 1D array\"\n",
"\n",
" # Convert to floating point if necessary\n",
" if not np.issubdtype(data.dtype, np.floating):\n",
" data = data.astype(np.float32) / np.iinfo(data.dtype).max\n",
" \n",
" # Supported sample rates\n",
" supported_sample_rates = (8000, 16000, 32000, 48000)\n",
" \n",
" # If sample rate is not supported, resample to 16000 Hz\n",
" if sample_rate not in supported_sample_rates:\n",
" data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)\n",
" sample_rate = 16000\n",
" \n",
" # Convert numpy array to PCM format\n",
" pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()\n",
"\n",
" return data, pcm_data"
]
},
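{
"cell_type": "code",
"execution_count": null,
"id": "read-wave-usage-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch (added for illustration, not part of the original demo):\n",
"# read_wave expects the (sample_rate, samples) tuple that a Gradio audio\n",
"# component with type=\"numpy\" returns. 22050 Hz is not a webrtcvad-supported\n",
"# rate, so read_wave should resample the synthetic signal below to 16000 Hz.\n",
"sketch_sr = 22050\n",
"sketch_samples = np.random.uniform(-0.5, 0.5, sketch_sr).astype(np.float32)\n",
"sketch_audio, sketch_pcm = read_wave((sketch_sr, sketch_samples))\n",
"print(sketch_audio.shape, len(sketch_pcm))  # roughly (16000,) and 32000 bytes"
]
},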
{
"cell_type": "code",
"execution_count": 10,
"id": "0b56a2fc-83c3-4b36-95b8-5f1b656150ed",
"metadata": {},
"outputs": [],
"source": [
"def VAD_chunk(aggressiveness, data):\n",
" audio, byte_audio = read_wave(data)\n",
" vad = webrtcvad.Vad(int(aggressiveness))\n",
" frames = frame_generator(20, byte_audio, hp.data.sr)\n",
" frames = list(frames)\n",
" times = vad_collector(hp.data.sr, 20, 200, vad, frames)\n",
" speech_times = []\n",
" speech_segs = []\n",
" for i, time in enumerate(times):\n",
" start = np.round(time[0],decimals=2)\n",
" end = np.round(time[1],decimals=2)\n",
" j = start\n",
" while j + .4 < end:\n",
" end_j = np.round(j+.4,decimals=2)\n",
" speech_times.append((j, end_j))\n",
" speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])\n",
" j = end_j\n",
" else:\n",
" speech_times.append((j, end))\n",
" speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])\n",
" return speech_times, speech_segs"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "72f257cf-7d3f-4ec5-944a-57779ba377e6",
"metadata": {},
"outputs": [],
"source": [
"def get_embedding(data, embedder_net, device, n_threshold=-1):\n",
" times, segs = VAD_chunk(0, data)\n",
" if not segs:\n",
" print(f'No voice activity detected')\n",
" return None\n",
" concat_seg = concat_segs(times, segs)\n",
" if not concat_seg:\n",
" print(f'No concatenated segments')\n",
" return None\n",
" STFT_frames = get_STFTs(concat_seg)\n",
" if not STFT_frames:\n",
" #print(f'No STFT frames')\n",
" return None\n",
" STFT_frames = np.stack(STFT_frames, axis=2)\n",
" STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)\n",
"\n",
" with torch.no_grad():\n",
" embeddings = embedder_net(STFT_frames)\n",
" embeddings = embeddings[:n_threshold, :]\n",
" \n",
" avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()\n",
" return avg_embedding"
]
},
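{
"cell_type": "code",
"execution_count": null,
"id": "cosine-similarity-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: speaker verification typically compares two utterance embeddings\n",
"# with cosine similarity. utils.evaluation (imported above) may already provide a\n",
"# scoring helper; this standalone version and the 0.75 threshold are illustrative\n",
"# assumptions, not tuned project values.\n",
"def cosine_similarity(emb1, emb2):\n",
"    \"\"\"Cosine similarity between two (1, D) embedding arrays.\"\"\"\n",
"    a, b = emb1.flatten(), emb2.flatten()\n",
"    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
"\n",
"SIMILARITY_THRESHOLD = 0.75  # placeholder; tune on a held-out verification set\n",
"\n",
"def same_speaker(emb1, emb2, threshold=SIMILARITY_THRESHOLD):\n",
"    \"\"\"Decide whether two embeddings belong to the same speaker.\"\"\"\n",
"    return cosine_similarity(emb1, emb2) >= threshold"
]
},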
{
"cell_type": "code",
"execution_count": 12,
"id": "200df766-407d-4367-b0fb-7a6118653731",
"metadata": {},
"outputs": [],
"source": [
"model_path = \"./speech_id_checkpoint/saved_01.model\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "db7613e6-67a8-4920-a999-caca4a0de360",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SpeechEmbedder(\n",
" (LSTM_stack): LSTM(40, 768, num_layers=3, batch_first=True)\n",
" (projection): Linear(in_features=768, out_features=256, bias=True)\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n",
"\n",
"embedder_net = SpeechEmbedder().to(device)\n",
"embedder_net.load_state_dict(torch.load(model_path, map_location=device))\n",
"embedder_net.eval()"
]
},
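{
"cell_type": "code",
"execution_count": null,
"id": "end-to-end-verification-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Added end-to-end sketch: with the model loaded, an enrollment clip and a test\n",
"# clip can be embedded and compared. The wav paths below are hypothetical\n",
"# examples (no such files ship with the notebook), so the calls are commented out.\n",
"def embedding_from_file(path):\n",
"    \"\"\"Load a wav with librosa and embed it with the pipeline defined above.\"\"\"\n",
"    samples, sr = librosa.load(path, sr=None)  # mono float32 at the native rate\n",
"    return get_embedding((sr, samples), embedder_net, device)\n",
"\n",
"# enroll_emb = embedding_from_file(\"example_audio/enroll.wav\")  # hypothetical path\n",
"# test_emb = embedding_from_file(\"example_audio/test.wav\")      # hypothetical path\n",
"# if enroll_emb is not None and test_emb is not None:\n",
"#     print(cosine_similarity(enroll_emb, test_emb))"
]
},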
{
"cell_type": "code",
"execution_count": 14,
"id": "8a7dd9bd-7b40-41f9-8e2f-d68be18f2111",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "bd6c073d-eab8-4ae6-8ba6-d90a0ec54c0e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://127.0.0.1:7868\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"