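"""Gradio demo: speech emotion recognition with a CNN over mel spectrograms."""
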
import gradio as gr
import torch
import torchaudio
from torch import nn

from model import CNNEmotinoalClassifier  # spelling matches the class as defined in model.py

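# Load the trained CNN emotion classifier on CPU and switch to inference mode.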
model = CNNEmotinoalClassifier()
model.load_state_dict(torch.load('./cnn_class_17.pt', map_location=torch.device('cpu')))
model.eval()

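# Mel-spectrogram front end; these hyperparameters are assumed to match the
# settings used during training.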
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050,
    n_fft=1024,
    hop_length=512,
    n_mels=64,
)

def _get_right_pad(target_num_samples, waveform):
    # Zero-pad the waveform on the right so every input has the same length.
    num_samples = waveform.shape[1]
    if num_samples < target_num_samples:
        padding = (0, target_num_samples - num_samples)
        waveform = nn.functional.pad(waveform, padding)
    return waveform

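# Inference entry point for the Gradio interface: takes the path to the
# recorded audio and returns a score per emotion.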
def get_probs(audio):
    if audio is None:
        return {}
    # Labels are sorted alphabetically to match the assumed training class order.
    emotions = sorted(['happy', 'angry', 'sad', 'neutral', 'surprised', 'fear'])

    waveform, sr = torchaudio.load(audio)
    # Pad to a fixed 400384 samples (~18.2 s at 22050 Hz).
    waveform = _get_right_pad(400384, waveform)
    input_x = to_melspec(waveform)
    input_x = torch.unsqueeze(input_x, dim=1)  # (channels, 1, n_mels, time)

    # The model is assumed to output per-class probabilities; if it returns raw
    # logits, apply torch.softmax(probs, dim=1) before reporting.
    probs = model(input_x)
    return dict(zip(emotions, map(float, probs[0])))

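# Gradio UI: a microphone input saved to a temp file, and a Label output that
# renders the dict of emotion scores.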
audio_input = gr.Audio(sources="microphone", type="filepath")
label = gr.Label()
examples = ['Akzhol_happy.wav']

iface = gr.Interface(fn=get_probs, inputs=audio_input, outputs=label, examples=examples)
iface.launch()
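
# Hypothetical usage: get_probs('Akzhol_happy.wav') returns a dict like
# {'angry': 0.02, 'fear': 0.01, 'happy': 0.91, ...}, which gr.Label renders
# as ranked confidences.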