voice-safety-classifier / inference.py
josephliu-roblox's picture
Add README and supporting code
a6b8100
raw
history blame
No virus
3.33 kB
# Copyright © 2024 Roblox Corporation
"""
This file is a sample demonstration of how to run inference with the Voice Safety Classifier model in Python, using the helper functions defined below.
"""
import torch
import librosa
import numpy as np
import argparse
from transformers import WavLMForSequenceClassification
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Split an audio signal into fixed-length windows for WavLM.

    Parameters
    ----------
    wav : str, os.PathLike or array-like
        Path to an audio file (loaded through librosa at sample rate ``sr``),
        or the samples themselves. Array-like input is converted with
        ``np.array`` and squeezed, so a ``(1, T)`` input becomes ``(T,)``.
    sr : int, optional
        Sample rate, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Distance between consecutive window starts in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to normalize each window to zero mean and unit variance,
        by default False.

    Returns
    -------
    np.ndarray
        Batched input to WavLM, shape ``[N, T]``. When the signal is longer
        than ``win_len`` seconds, ``T == int(win_len * sr)`` and the final
        window is zero-padded on the right. Otherwise a single window holding
        the whole signal is returned *without* padding.

    Raises
    ------
    RuntimeError
        If array-like input cannot be converted to a NumPy array.
    ValueError
        If the signal has to be chunked but ``win_stride * sr`` is less than
        one sample.
    """
    import os  # local import: only needed for the path check below

    if isinstance(wav, (str, os.PathLike)):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            # atleast_1d keeps a single-sample input measurable with len();
            # squeeze() alone would collapse it to a 0-d array.
            signal = np.atleast_1d(np.array(wav).squeeze())
        except Exception as e:
            raise RuntimeError(
                f"could not convert input of type {type(wav).__name__} "
                "to a NumPy array"
            ) from e

    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    n_samples = len(signal)

    batched_input = []
    if n_samples / sr > win_len:
        if stride <= 0:
            raise ValueError(
                "win_stride * sr must be at least one sample, "
                f"got win_stride={win_stride!r}, sr={sr!r}"
            )
        for start in range(0, n_samples, stride):
            end = start + win_samples
            runs_past_end = end > n_samples
            chunk = signal[start:end]
            if runs_past_end:
                # Zero-pad the last chunk so every window has the same length.
                chunk = np.pad(chunk, (0, win_samples - len(chunk)))
            if do_normalize:
                chunk = (chunk - np.mean(chunk)) / (np.std(chunk) + 1e-7)
            batched_input.append(chunk)
            if runs_past_end:
                # The tail of the signal is covered; with overlapping windows
                # any later start would only add more padding.
                break
    else:
        # Signal fits in one window: passed through as-is (no padding).
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)

    return np.stack(batched_input)  # [N, T]
def infer(model, inputs):
    """Run the classifier on a batch and return per-label probabilities.

    Parameters
    ----------
    model : callable
        Invoked as ``model(inputs)``; the returned object must expose a
        ``logits`` attribute.
    inputs : torch.Tensor
        Batch handed to ``model`` unchanged.

    Returns
    -------
    torch.Tensor
        Element-wise sigmoid of the logits (one independent probability per
        label), same shape as the logits. The result carries no autograd
        history.
    """
    # Inference only: disabling autograd avoids building a graph nobody uses.
    # The model's train/eval mode is left untouched.
    with torch.no_grad():
        logits = model(inputs).logits
        if not isinstance(logits, torch.Tensor):
            # Tolerate models that hand back lists / NumPy arrays.
            logits = torch.as_tensor(logits, dtype=torch.float32)
        return torch.sigmoid(logits)
if __name__ == "__main__":
    # Command-line demo: load one audio file, split it into windows, run the
    # classifier and print the per-label probabilities of every window.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,  # without it the loader would be handed None
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="checkpoint file of model",
    )
    args = parser.parse_args()

    # Order matters: index i of the model output is reported as label i.
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]

    # Model is trained on only 16kHz audio
    sample_rate = 16000
    audio, _ = librosa.core.load(args.audio_file, sr=sample_rate)
    input_np = feature_extract_simple(audio, sr=sample_rate)
    input_pt = torch.Tensor(input_np)

    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )

    probs = infer(model, input_pt)
    # One row per audio window, one column per label.
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()

    print(f"Probabilities for {args.audio_file} is:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")