# gujarati-tisv / app.py
import torch
import librosa
import numpy as np
import os
import webrtcvad
import wave
import contextlib
import gradio as gr
from utils.VAD_segments import *
from utils.hparam import hparam as hp
from utils.speech_embedder_net import *
from utils.evaluation import *
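
# End-to-end flow (as implemented below): two recordings -> webrtcvad voice
# activity detection -> fixed-length speech segments -> STFT frames -> LSTM
# SpeechEmbedder (trained with GE2E loss) -> one averaged embedding per
# recording -> cosine similarity compared against a user-set threshold.
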
def read_wave(audio_data):
    """Read a Gradio-style (sample_rate, numpy_array) tuple.

    Returns (float_audio, pcm_data): the audio as a float array in [-1, 1]
    and the same audio as 16-bit PCM bytes for webrtcvad. If the sample rate
    is not supported by webrtcvad, the audio is resampled to 16000 Hz.
    """
    sample_rate, data = audio_data
    # Ensure the audio is mono (a 1-D array).
    assert len(data.shape) == 1, "Audio data must be a 1D array"
    # Convert integer PCM to floating point in [-1, 1] if necessary.
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    # Sample rates accepted by webrtcvad.
    supported_sample_rates = (8000, 16000, 32000, 48000)
    # If the sample rate is not supported, resample to 16000 Hz.
    if sample_rate not in supported_sample_rates:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000
    # Convert the float array to 16-bit PCM bytes for the VAD.
    pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()
    return data, pcm_data


def VAD_chunk(aggressiveness, data):
    """Split the recording into voiced chunks of at most 0.4 s using webrtcvad."""
    audio, byte_audio = read_wave(data)
    vad = webrtcvad.Vad(int(aggressiveness))
    # 20 ms frames; hp.data.sr is assumed to match the 16 kHz target used in read_wave.
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for time in times:
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j * hp.data.sr):int(end_j * hp.data.sr)])
            j = end_j
        else:
            # while-else: runs once the loop exits, appending the remaining
            # (shorter than 0.4 s) tail segment from j to end.
            speech_times.append((j, end))
            speech_segs.append(audio[int(j * hp.data.sr):int(end * hp.data.sr)])
    return speech_times, speech_segs


def get_embedding(data, embedder_net, device, n_threshold=-1):
    """Return the average speaker embedding for one recording, or None if no speech is found."""
    times, segs = VAD_chunk(0, data)
    if not segs:
        print('No voice activity detected')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print('No concatenated segments')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        # print('No STFT frames')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
    embeddings = embeddings[:n_threshold, :]
    # Average the window-level embeddings into a single utterance-level embedding.
    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding


# Load the trained speaker embedding model for CPU inference.
model_path = "./speech_id_checkpoint/saved_02.model"
embedder_net = SpeechEmbedder()
embedder_net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
embedder_net.eval()


def process_audio(audio1, audio2, threshold):
    """Compare two recordings and decide whether they come from the same speaker."""
    e1 = get_embedding(audio1, embedder_net, torch.device("cpu"))
    if e1 is None:
        return "No Voice Detected in file 1"
    e2 = get_embedding(audio2, embedder_net, torch.device("cpu"))
    if e2 is None:
        return "No Voice Detected in file 2"
    cosi = cosine_similarity(e1, e2)
    # Same speaker if the cosine similarity exceeds the user-chosen threshold.
    if cosi > threshold:
        return "Same Speaker"
    else:
        return "Different Speaker"


# Define the Gradio interface
def gradio_interface(audio1, audio2, threshold):
    return process_audio(audio1, audio2, threshold)

description = """
<p>
<center>
This is an LSTM based Speaker Embedding Model trained using <a href="https://arxiv.org/abs/1710.10467">GE2E loss</a> on the <a href="https://openslr.org/78/">Gujarati OpenSLR dataset</a>.
<img src="https://huggingface.co/spaces/1rsh/gujarati-tisv/resolve/main/img/gujarati-text.png" alt="Gujarati" width="250"/>
</center>
</p>
"""
# Create the Gradio interface with microphone inputs
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # Microphone inputs; "sources" is the Gradio 4.x parameter name.
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 1"),
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 2"),
        gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Threshold"),
    ],
    outputs="text",
    title="ગુજરાતી Text Independent Speaker Verification",
    description=description,
)
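
# Note: share=False keeps the app local; on Hugging Face Spaces the platform
# serves it. Set share=True to get a temporary public link from a local run.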
# Launch the interface
iface.launch(share=False)