AlienKevin commited on
Commit
b256b6f
1 Parent(s): 7741af3

Add app.py and model

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import jyutping
3
+ from whisper_audio_classifier import WhisperAudioClassifier
4
+ import librosa
5
+ from transformers import WhisperFeatureExtractor
6
+
7
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(f"alvanlii/whisper-small-cantonese")
8
+ feature_extractor.chunk_length = 3
9
+
10
+ # Instantiate the model
11
+ device = torch.device("mps")
12
+ model = WhisperAudioClassifier().to(device)
13
+
14
+ # Load the state dict
15
+ state_dict = torch.load(f"whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth")
16
+
17
+ # Load the state dict into the model
18
+ model.load_state_dict(state_dict)
19
+
20
+ # Set the model to evaluation mode
21
+ model.eval()
22
+
23
+ def predict(audio):
24
+ features = feature_extractor(audio, sampling_rate=16000)
25
+ with torch.no_grad():
26
+ inputs = torch.from_numpy(features['input_features'][0]).to(device)
27
+ inputs = inputs.unsqueeze(0) # Add extra batch dimension in front
28
+ outs = model(inputs)
29
+ return [torch.softmax(tensor.squeeze(), dim=0).tolist() for tensor in outs]
30
+
31
+ import gradio as gr
32
+ import numpy as np
33
+
34
+ def rank_initials(preds, k=3):
35
+ ranked = sorted([((jyutping.inflate_initial(i) if jyutping.inflate_initial(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
36
+ return dict(ranked[:k])
37
+
38
+ def rank_nucli(preds, k=3):
39
+ ranked = sorted([((jyutping.inflate_nucleus(i) if jyutping.inflate_nucleus(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
40
+ return dict(ranked[:k])
41
+
42
+ def rank_codas(preds, k=3):
43
+ ranked = sorted([((jyutping.inflate_coda(i) if jyutping.inflate_coda(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
44
+ return dict(ranked[:k])
45
+
46
+ def rank_tones(preds, k=3):
47
+ ranked = sorted([(str(i + 1), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
48
+ return dict(ranked[:k])
49
+
50
+ def classify_audio(audio):
51
+ sampling_rate, audio = audio
52
+ audio = audio.astype(np.float32)
53
+ audio /= np.max(np.abs(audio))
54
+ audio_resampled = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
55
+ preds = predict(torch.from_numpy(audio_resampled))
56
+ return [
57
+ rank_initials(preds[0]),
58
+ rank_nucli(preds[1]),
59
+ rank_codas(preds[2]),
60
+ rank_tones(preds[3]),
61
+ rank_initials(preds[4]),
62
+ rank_nucli(preds[5]),
63
+ rank_codas(preds[6]),
64
+ rank_tones(preds[7]),
65
+ ]
66
+
67
+ with gr.Blocks() as demo:
68
+ with gr.Row():
69
+ inputs = gr.Audio(source="microphone", type="numpy", label="Input Audio")
70
+ submit_btn = gr.Button("Submit")
71
+
72
+ with gr.Row():
73
+ with gr.Column():
74
+ outputs_left = [
75
+ gr.Label(label="Initial 1"),
76
+ gr.Label(label="Nucleus 1"),
77
+ gr.Label(label="Coda 1"),
78
+ gr.Label(label="Tone 1"),
79
+ ]
80
+
81
+ with gr.Column():
82
+ outputs_right = [
83
+ gr.Label(label="Initial 2"),
84
+ gr.Label(label="Nucleus 2"),
85
+ gr.Label(label="Coda 2"),
86
+ gr.Label(label="Tone 2"),
87
+ ]
88
+
89
+ submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left+outputs_right)
90
+
91
+ demo.launch()
jyutping.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def extract_tone(syllable: str) -> int:
2
+ return int(syllable[-1]) - 1
3
+ jyutping_initials = ['∅', 'ng', 'gw', 'kw', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'w', 'z', 'c', 's', 'j']
4
+
5
+ def extract_initial(syllable: str) -> (int, str):
6
+ for i, initial in enumerate(jyutping_initials):
7
+ if syllable.startswith(initial):
8
+ return (i, initial)
9
+ return (0, '')
10
+
11
+ def inflate_initial(initial: int) -> str:
12
+ return jyutping_initials[initial] if initial != 0 else ''
13
+ jyutping_nuclei = ['∅', 'aa', 'yu', 'eo', 'oe', 'a', 'i', 'u', 'e', 'o']
14
+
15
+ def extract_nucleus(syllable: str, initial: str) -> int:
16
+ syllable = syllable[len(initial):]
17
+ for i, nucleus in enumerate(jyutping_nuclei):
18
+ if syllable.startswith(nucleus):
19
+ return (i, nucleus)
20
+ return (0, '')
21
+
22
+ def inflate_nucleus(nucleus: int) -> str:
23
+ return jyutping_nuclei[nucleus] if nucleus != 0 else ''
24
+ jyutping_codas = ['∅', 'ng', 'p', 't', 'k', 'm', 'n', 'i', 'u']
25
+
26
+ def extract_coda(syllable: str, initial: str, nucleus: str) -> int:
27
+ syllable = syllable[len(initial) + len(nucleus):]
28
+ for i, coda in enumerate(jyutping_codas):
29
+ if syllable.startswith(coda):
30
+ return (i, coda)
31
+ return (0, '')
32
+
33
+ def inflate_coda(coda: int) -> str:
34
+ return jyutping_codas[coda] if coda != 0 else ''
35
+ syllable = 'neoi5'
36
+
37
+ initial_int, initial = extract_initial(syllable)
38
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
39
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
40
+
41
+ assert(initial == 'n' and initial_int == 10)
42
+ assert(nucleus == 'eo' and nucleus_int == 3)
43
+ assert(coda == 'i' and coda_int == 7)
44
+ syllable = 'gwok3'
45
+
46
+ initial_int, initial = extract_initial(syllable)
47
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
48
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
49
+
50
+ assert(initial == 'gw' and initial_int == 2)
51
+ assert(nucleus == 'o' and nucleus_int == 9)
52
+ assert(coda == 'k' and coda_int == 4)
53
+ syllable = 'oi3'
54
+
55
+ initial_int, initial = extract_initial(syllable)
56
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
57
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
58
+
59
+ assert(initial == '' and initial_int == 0)
60
+ assert(nucleus == 'o' and nucleus_int == 9)
61
+ assert(coda == 'i' and coda_int == 7)
62
+ syllable = 'ng4'
63
+
64
+ initial_int, initial = extract_initial(syllable)
65
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
66
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
67
+
68
+ assert(initial == 'ng' and initial_int == 1)
69
+ assert(nucleus == '' and nucleus_int == 0)
70
+ assert(coda == '' and coda_int == 0)
71
+ syllable = 'ngo5'
72
+
73
+ initial_int, initial = extract_initial(syllable)
74
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
75
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
76
+
77
+ assert(initial == 'ng' and initial_int == 1)
78
+ assert(nucleus == 'o' and nucleus_int == 9)
79
+ assert(coda == '' and coda_int == 0)
80
+ syllable = 'a3'
81
+
82
+ initial_int, initial = extract_initial(syllable)
83
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
84
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
85
+
86
+ assert(initial == '' and initial_int == 0)
87
+ assert(nucleus == 'a' and nucleus_int == 5)
88
+ assert(coda == '' and coda_int == 0)
89
+ syllable = 'aa3'
90
+
91
+ initial_int, initial = extract_initial(syllable)
92
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
93
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
94
+
95
+ assert(initial == '' and initial_int == 0)
96
+ assert(nucleus == 'aa' and nucleus_int == 1)
97
+ assert(coda == '' and coda_int == 0)
98
+ syllable = 'ngaang6'
99
+
100
+ initial_int, initial = extract_initial(syllable)
101
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
102
+ coda_int, coda = extract_coda(syllable, initial, nucleus)
103
+
104
+ assert(initial == 'ng' and initial_int == 1)
105
+ assert(nucleus == 'aa' and nucleus_int == 1)
106
+ assert(coda == 'ng' and coda_int == 1)
107
+ def extract_jyutping(syllable: str) -> (int, int, int, int):
108
+ initial_int, initial = extract_initial(syllable)
109
+ nucleus_int, nucleus = extract_nucleus(syllable, initial)
110
+ coda_int, _ = extract_coda(syllable, initial, nucleus)
111
+ tone = extract_tone(syllable)
112
+ return (initial_int, nucleus_int, coda_int, tone)
113
+
114
+ def inflate_jyutping(initial: int, nucleus: int, coda: int, tone: int) -> str:
115
+ return f"{inflate_initial(initial)}{inflate_nucleus(nucleus)}{inflate_coda(coda)}{tone + 1}"
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ librosa
whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f85b07f5287b93c473b234696b387a6b8ff1414bc1980b811f7933f9ecddb28
3
+ size 390773292
whisper_audio_classifier.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import WhisperModel
2
+ from torch import nn
3
+ import torch
4
+ from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas
5
+
6
+ class WhisperAudioClassifier(nn.Module):
7
+ def __init__(self):
8
+ super(WhisperAudioClassifier, self).__init__()
9
+ # Load the Whisper model encoder
10
+ self.whisper_encoder = WhisperModel.from_pretrained(f"alvanlii/whisper-small-cantonese", device_map="auto").get_encoder()
11
+ self.whisper_encoder.eval() # Set the Whisper model to evaluation mode
12
+
13
+ # Assuming we know the output size of the Whisper encoder, or it needs to be determined
14
+ whisper_output_size = 768
15
+
16
+ self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
17
+ self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
18
+ self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
19
+ self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
20
+
21
+ self.pool = nn.AdaptiveAvgPool1d(1)
22
+
23
+ # Separate output layers for each class set
24
+ self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
25
+ self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
26
+ self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
27
+ self.tone_fc1 = nn.Linear(whisper_output_size, 6)
28
+
29
+ self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
30
+ self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
31
+ self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
32
+ self.tone_fc2 = nn.Linear(whisper_output_size, 6)
33
+
34
+ self.dropout = nn.Dropout(0.1)
35
+
36
+ def forward(self, x):
37
+ # Use Whisper model to encode audio input
38
+ with torch.no_grad(): # No need to track gradients for the encoder
39
+ x = self.whisper_encoder(x).last_hidden_state
40
+
41
+ initial, _ = self.initial_attention(x, x, x, need_weights=False)
42
+ initial = initial.permute(0, 2, 1) # [batch_size, channels, seq_len]
43
+ initial = self.pool(initial) # [batch_size, channels, 1]
44
+ initial = initial.squeeze(-1) # [batch_size, channels]
45
+ initial_out1 = self.initial_fc1(initial)
46
+ initial_out2 = self.initial_fc2(initial)
47
+
48
+ nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
49
+ nucleus = nucleus.permute(0, 2, 1) # [batch_size, channels, seq_len]
50
+ nucleus = self.pool(nucleus) # [batch_size, channels, 1]
51
+ nucleus = nucleus.squeeze(-1) # [batch_size, channels]
52
+ nucleus_out1 = self.nucleus_fc1(nucleus)
53
+ nucleus_out2 = self.nucleus_fc2(nucleus)
54
+
55
+ coda, _ = self.coda_attention(x, x, x, need_weights=False)
56
+ coda = coda.permute(0, 2, 1) # [batch_size, channels, seq_len]
57
+ coda = self.pool(coda) # [batch_size, channels, 1]
58
+ coda = coda.squeeze(-1) # [batch_size, channels]
59
+ coda_out1 = self.coda_fc1(coda)
60
+ coda_out2 = self.coda_fc2(coda)
61
+
62
+ tone, _ = self.tone_attention(x, x, x, need_weights=False)
63
+ tone = tone.permute(0, 2, 1) # [batch_size, channels, seq_len]
64
+ tone = self.pool(tone) # [batch_size, channels, 1]
65
+ tone = tone.squeeze(-1) # [batch_size, channels]
66
+ tone_out1 = self.tone_fc1(tone)
67
+ tone_out2 = self.tone_fc2(tone)
68
+
69
+ return initial_out1, nucleus_out1, coda_out1, tone_out1, initial_out2, nucleus_out2, coda_out2, tone_out2