AlienKevin committed · Commit b256b6f · Parent(s): 7741af3

Add app.py and model
Browse files
- .gitignore +1 -0
- app.py +91 -0
- jyutping.py +115 -0
- requirements.txt +3 -0
- whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth +3 -0
- whisper_audio_classifier.py +69 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+.DS_Store
app.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+import jyutping
+from whisper_audio_classifier import WhisperAudioClassifier
+import librosa
+from transformers import WhisperFeatureExtractor
+
+feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
+feature_extractor.chunk_length = 3
+
+# Instantiate the model; prefer Apple MPS when available, otherwise fall back to CPU
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+model = WhisperAudioClassifier().to(device)
+
+# Load the state dict (map_location lets the checkpoint load on non-MPS machines)
+state_dict = torch.load("whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth", map_location=device)
+
+# Load the state dict into the model
+model.load_state_dict(state_dict)
+
+# Set the model to evaluation mode
+model.eval()
+
+def predict(audio):
+    features = feature_extractor(audio, sampling_rate=16000)
+    with torch.no_grad():
+        inputs = torch.from_numpy(features['input_features'][0]).to(device)
+        inputs = inputs.unsqueeze(0)  # Add extra batch dimension in front
+        outs = model(inputs)
+    return [torch.softmax(tensor.squeeze(), dim=0).tolist() for tensor in outs]
+
+import gradio as gr
+import numpy as np
+
+# Each rank_* helper maps class indices back to Jyutping symbols
+# ('∅' marks an empty slot) and keeps the top-k most probable candidates.
+def rank_initials(preds, k=3):
+    ranked = sorted([((jyutping.inflate_initial(i) if jyutping.inflate_initial(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
+    return dict(ranked[:k])
+
+def rank_nuclei(preds, k=3):
+    ranked = sorted([((jyutping.inflate_nucleus(i) if jyutping.inflate_nucleus(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
+    return dict(ranked[:k])
+
+def rank_codas(preds, k=3):
+    ranked = sorted([((jyutping.inflate_coda(i) if jyutping.inflate_coda(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
+    return dict(ranked[:k])
+
+def rank_tones(preds, k=3):
+    ranked = sorted([(str(i + 1), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
+    return dict(ranked[:k])
+
+def classify_audio(audio):
+    sampling_rate, audio = audio
+    # Normalize to [-1, 1] and resample to the 16 kHz rate Whisper expects
+    audio = audio.astype(np.float32)
+    audio /= np.max(np.abs(audio))
+    audio_resampled = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
+    preds = predict(torch.from_numpy(audio_resampled))
+    return [
+        rank_initials(preds[0]),
+        rank_nuclei(preds[1]),
+        rank_codas(preds[2]),
+        rank_tones(preds[3]),
+        rank_initials(preds[4]),
+        rank_nuclei(preds[5]),
+        rank_codas(preds[6]),
+        rank_tones(preds[7]),
+    ]
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        inputs = gr.Audio(source="microphone", type="numpy", label="Input Audio")
+        submit_btn = gr.Button("Submit")
+
+    with gr.Row():
+        with gr.Column():
+            outputs_left = [
+                gr.Label(label="Initial 1"),
+                gr.Label(label="Nucleus 1"),
+                gr.Label(label="Coda 1"),
+                gr.Label(label="Tone 1"),
+            ]
+
+        with gr.Column():
+            outputs_right = [
+                gr.Label(label="Initial 2"),
+                gr.Label(label="Nucleus 2"),
+                gr.Label(label="Coda 2"),
+                gr.Label(label="Tone 2"),
+            ]
+
+    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left+outputs_right)
+
+demo.launch()
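For a quick check outside the Gradio UI, classify_audio can be called directly with the same (sample_rate, waveform) tuple that gr.Audio delivers. A minimal sketch, assuming the checkpoint above is present and the Whisper weights can be downloaded; the 440 Hz tone is purely illustrative:

    import numpy as np
    sr = 48000  # mimic a typical microphone sample rate
    t = np.linspace(0, 0.5, sr // 2, endpoint=False)
    wave = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    results = classify_audio((sr, wave))
    # Eight dicts of top-3 {symbol: probability}: initial/nucleus/coda/tone
    # for syllable 1, then the same four for syllable 2
    print(results[0])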
jyutping.py
ADDED
@@ -0,0 +1,115 @@
+def extract_tone(syllable: str) -> int:
+    return int(syllable[-1]) - 1
+
+# Index 0 ('∅') marks an empty slot; longer strings precede their prefixes
+# (e.g. 'ng' before 'n', 'gw' before 'g') so startswith matches greedily.
+jyutping_initials = ['∅', 'ng', 'gw', 'kw', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'w', 'z', 'c', 's', 'j']
+
+def extract_initial(syllable: str) -> (int, str):
+    for i, initial in enumerate(jyutping_initials):
+        if syllable.startswith(initial):
+            return (i, initial)
+    return (0, '')
+
+def inflate_initial(initial: int) -> str:
+    return jyutping_initials[initial] if initial != 0 else ''
+
+jyutping_nuclei = ['∅', 'aa', 'yu', 'eo', 'oe', 'a', 'i', 'u', 'e', 'o']
+
+def extract_nucleus(syllable: str, initial: str) -> (int, str):
+    syllable = syllable[len(initial):]
+    for i, nucleus in enumerate(jyutping_nuclei):
+        if syllable.startswith(nucleus):
+            return (i, nucleus)
+    return (0, '')
+
+def inflate_nucleus(nucleus: int) -> str:
+    return jyutping_nuclei[nucleus] if nucleus != 0 else ''
+
+jyutping_codas = ['∅', 'ng', 'p', 't', 'k', 'm', 'n', 'i', 'u']
+
+def extract_coda(syllable: str, initial: str, nucleus: str) -> (int, str):
+    syllable = syllable[len(initial) + len(nucleus):]
+    for i, coda in enumerate(jyutping_codas):
+        if syllable.startswith(coda):
+            return (i, coda)
+    return (0, '')
+
+def inflate_coda(coda: int) -> str:
+    return jyutping_codas[coda] if coda != 0 else ''
+
+# Inline self-tests; these run once when the module is imported.
+syllable = 'neoi5'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == 'n' and initial_int == 10)
+assert(nucleus == 'eo' and nucleus_int == 3)
+assert(coda == 'i' and coda_int == 7)
+
+syllable = 'gwok3'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == 'gw' and initial_int == 2)
+assert(nucleus == 'o' and nucleus_int == 9)
+assert(coda == 'k' and coda_int == 4)
+
+syllable = 'oi3'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == '' and initial_int == 0)
+assert(nucleus == 'o' and nucleus_int == 9)
+assert(coda == 'i' and coda_int == 7)
+
+syllable = 'ng4'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == 'ng' and initial_int == 1)
+assert(nucleus == '' and nucleus_int == 0)
+assert(coda == '' and coda_int == 0)
+
+syllable = 'ngo5'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == 'ng' and initial_int == 1)
+assert(nucleus == 'o' and nucleus_int == 9)
+assert(coda == '' and coda_int == 0)
+
+syllable = 'a3'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == '' and initial_int == 0)
+assert(nucleus == 'a' and nucleus_int == 5)
+assert(coda == '' and coda_int == 0)
+
+syllable = 'aa3'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == '' and initial_int == 0)
+assert(nucleus == 'aa' and nucleus_int == 1)
+assert(coda == '' and coda_int == 0)
+
+syllable = 'ngaang6'
+
+initial_int, initial = extract_initial(syllable)
+nucleus_int, nucleus = extract_nucleus(syllable, initial)
+coda_int, coda = extract_coda(syllable, initial, nucleus)
+
+assert(initial == 'ng' and initial_int == 1)
+assert(nucleus == 'aa' and nucleus_int == 1)
+assert(coda == 'ng' and coda_int == 1)
+
+def extract_jyutping(syllable: str) -> (int, int, int, int):
+    initial_int, initial = extract_initial(syllable)
+    nucleus_int, nucleus = extract_nucleus(syllable, initial)
+    coda_int, _ = extract_coda(syllable, initial, nucleus)
+    tone = extract_tone(syllable)
+    return (initial_int, nucleus_int, coda_int, tone)
+
+def inflate_jyutping(initial: int, nucleus: int, coda: int, tone: int) -> str:
+    return f"{inflate_initial(initial)}{inflate_nucleus(nucleus)}{inflate_coda(coda)}{tone + 1}"
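The two top-level helpers compose everything above into an encode/decode pair. A small round-trip sketch, consistent with the inline asserts (indices and tones are zero-based internally; the tone digit is restored on inflation):

    import jyutping
    assert jyutping.extract_jyutping('gwok3') == (2, 9, 4, 2)
    assert jyutping.inflate_jyutping(2, 9, 4, 2) == 'gwok3'
    # Empty slots inflate to '', so syllabic 'ng4' also round-trips
    assert jyutping.inflate_jyutping(*jyutping.extract_jyutping('ng4')) == 'ng4'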
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+transformers
+torch
+librosa
whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f85b07f5287b93c473b234696b387a6b8ff1414bc1980b811f7933f9ecddb28
+size 390773292
whisper_audio_classifier.py
ADDED
@@ -0,0 +1,69 @@
+from transformers import WhisperModel
+from torch import nn
+import torch
+from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas
+
+class WhisperAudioClassifier(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Load the encoder half of the Whisper model
+        self.whisper_encoder = WhisperModel.from_pretrained("alvanlii/whisper-small-cantonese", device_map="auto").get_encoder()
+        self.whisper_encoder.eval()  # Set the Whisper encoder to evaluation mode
+
+        # Hidden size (d_model) of the whisper-small encoder
+        whisper_output_size = 768
+
+        # One multi-head self-attention block per phonological slot
+        self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
+        self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
+        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
+        self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
+
+        self.pool = nn.AdaptiveAvgPool1d(1)
+
+        # Separate output layers for each class set; *_fc1 predicts the first
+        # syllable of the bisyllabic word and *_fc2 the second.
+        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
+        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
+        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
+        self.tone_fc1 = nn.Linear(whisper_output_size, 6)
+
+        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
+        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
+        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
+        self.tone_fc2 = nn.Linear(whisper_output_size, 6)
+
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, x):
+        # Use the Whisper encoder to embed the audio input
+        with torch.no_grad():  # No need to track gradients for the frozen encoder
+            x = self.whisper_encoder(x).last_hidden_state
+
+        initial, _ = self.initial_attention(x, x, x, need_weights=False)
+        initial = initial.permute(0, 2, 1)  # [batch_size, channels, seq_len]
+        initial = self.pool(initial)  # [batch_size, channels, 1]
+        initial = initial.squeeze(-1)  # [batch_size, channels]
+        initial_out1 = self.initial_fc1(initial)
+        initial_out2 = self.initial_fc2(initial)
+
+        nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
+        nucleus = nucleus.permute(0, 2, 1)  # [batch_size, channels, seq_len]
+        nucleus = self.pool(nucleus)  # [batch_size, channels, 1]
+        nucleus = nucleus.squeeze(-1)  # [batch_size, channels]
+        nucleus_out1 = self.nucleus_fc1(nucleus)
+        nucleus_out2 = self.nucleus_fc2(nucleus)
+
+        coda, _ = self.coda_attention(x, x, x, need_weights=False)
+        coda = coda.permute(0, 2, 1)  # [batch_size, channels, seq_len]
+        coda = self.pool(coda)  # [batch_size, channels, 1]
+        coda = coda.squeeze(-1)  # [batch_size, channels]
+        coda_out1 = self.coda_fc1(coda)
+        coda_out2 = self.coda_fc2(coda)
+
+        tone, _ = self.tone_attention(x, x, x, need_weights=False)
+        tone = tone.permute(0, 2, 1)  # [batch_size, channels, seq_len]
+        tone = self.pool(tone)  # [batch_size, channels, 1]
+        tone = tone.squeeze(-1)  # [batch_size, channels]
+        tone_out1 = self.tone_fc1(tone)
+        tone_out2 = self.tone_fc2(tone)
+
+        return initial_out1, nucleus_out1, coda_out1, tone_out1, initial_out2, nucleus_out2, coda_out2, tone_out2
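A minimal forward-pass smoke test, mirroring the 3-second feature setup in app.py. This is a sketch, not part of the commit: it assumes a CPU-only environment (so device_map="auto" keeps every module on one device) and a transformers version that, like app.py, accepts 3-second inputs:

    import numpy as np
    import torch
    from transformers import WhisperFeatureExtractor
    from whisper_audio_classifier import WhisperAudioClassifier

    fe = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
    fe.chunk_length = 3
    silence = np.zeros(3 * 16000, dtype=np.float32)  # 3 s at 16 kHz
    feats = torch.from_numpy(fe(silence, sampling_rate=16000)["input_features"][0]).unsqueeze(0)
    model = WhisperAudioClassifier().eval()
    with torch.no_grad():
        outs = model(feats)
    # Eight logit tensors: initial (20), nucleus (10), coda (9), tone (6)
    # for syllable 1, then the same four for syllable 2
    print([tuple(o.shape) for o in outs])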