File size: 3,325 Bytes
a6b8100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright © 2024 Roblox Corporation

"""
This file demonstrates how to use the functions below in Python to run the Voice Safety Classifier model.
"""

import argparse
import os

import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for WavLM: split audio into fixed-length windows.

    Parameters
    ----------
    wav : str, os.PathLike or array-like
        Path to an audio file (loaded with ``librosa`` and resampled to
        ``sr``), or an array-like of samples assumed to already be at ``sr``.
        Singleton dimensions of an array-like are squeezed away.
    sr : int, optional
        Sample rate, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Window stride in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to normalize each window to zero mean / unit variance,
        by default False.

    Returns
    -------
    np.ndarray
        Batched input to WavLM with shape ``[N, T]``. If the signal is no
        longer than ``win_len`` seconds, ``N == 1`` and ``T`` is the signal
        length (no padding). Otherwise ``T == int(win_len * sr)``. The last
        window is zero-padded to that length, and windowing stops at the
        first window that reaches the end of the signal.

    Raises
    ------
    RuntimeError
        If ``wav`` cannot be converted to a numpy array.
    ValueError
        If the signal must be windowed but ``win_len`` or ``win_stride``
        amount to fewer than one sample.
    """
    if isinstance(wav, (str, os.PathLike)):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            # atleast_1d: squeezing a single-sample input yields a 0-d array,
            # which has no len().
            signal = np.atleast_1d(np.array(wav).squeeze())
        except Exception as e:
            raise RuntimeError(f"could not convert input to an array: {e}") from e

    win = int(win_len * sr)
    stride = int(win_stride * sr)

    def _normalize(x):
        # Zero-mean / unit-variance; the epsilon avoids division by zero on
        # constant (e.g. silent) windows.
        return (x - np.mean(x)) / (np.std(x) + 1e-7)

    n_samples = len(signal)
    if n_samples <= win:
        # Short clip: a single, unpadded window.
        return np.stack([_normalize(signal) if do_normalize else signal])

    if win <= 0 or stride <= 0:
        raise ValueError(
            f"win_len and win_stride must each span at least one sample "
            f"(got {win} and {stride} samples at sr={sr})"
        )

    batched_input = []
    for start in range(0, n_samples, stride):
        chunk = signal[start : start + win]
        if len(chunk) < win:
            # Zero-pad the final window so every row has the same length.
            chunk = np.pad(chunk, (0, win - len(chunk)))
        if do_normalize:
            chunk = _normalize(chunk)
        batched_input.append(chunk)
        if start + win >= n_samples:
            # This window already reaches the end of the signal; any later
            # start would only re-cover samples included here.
            break
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs):
    """Run ``model`` on ``inputs`` and return element-wise sigmoid probabilities.

    Parameters
    ----------
    model : callable
        Model whose output exposes a ``logits`` attribute (e.g. a
        ``WavLMForSequenceClassification``).
    inputs : torch.Tensor
        Batched model input.

    Returns
    -------
    torch.Tensor
        ``sigmoid(logits)`` as float32. The forward pass runs without
        autograd tracking, so the result does not require grad.
    """
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        output = model(inputs)
    # as_tensor avoids a copy when logits are already a tensor; cast to float32.
    return torch.sigmoid(torch.as_tensor(output.logits).float())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        # Without an audio file there is nothing to classify; fail fast in
        # argparse instead of inside librosa.
        required=True,
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="checkpoint file of model",
    )
    args = parser.parse_args()
    # Column order of the classifier output (presumably the checkpoint's
    # label order — confirm against the model config).
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    num_labels = len(labels_name_list)
    # Model is trained on only 16kHz audio
    sample_rate = 16_000
    audio, _ = librosa.core.load(args.audio_file, sr=sample_rate)
    input_np = feature_extract_simple(audio, sr=sample_rate)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=num_labels
    )
    probs = infer(model, input_pt)
    # One row per audio segment, one column per label.
    probs = probs.reshape(-1, num_labels).detach().tolist()
    print(f"Probabilities for {args.audio_file} is:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")