Update whisper-speaker-diarization/src/worker.js
whisper-speaker-diarization/src/worker.js
CHANGED
@@ -1,124 +1,124 @@
 
 import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@xenova/transformers';
 
 const PER_DEVICE_CONFIG = {
     webgpu: {
         dtype: {
             encoder_model: 'fp32',
             decoder_model_merged: 'q4',
         },
         device: 'webgpu',
     },
     wasm: {
         dtype: 'q8',
         device: 'wasm',
     },
 };
 
 /**
  * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
  */
 class PipelineSingeton {
-    static asr_model_id = '
+    static asr_model_id = 'Xenova/whisper-large-v3';
     static asr_instance = null;
 
     static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
     static segmentation_instance = null;
     static segmentation_processor = null;
 
     static async getInstance(progress_callback = null, device = 'webgpu') {
         this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
             ...PER_DEVICE_CONFIG[device],
             progress_callback,
         });
 
         this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
             progress_callback,
         });
         this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
             // NOTE: WebGPU is not currently supported for this model
             // See https://github.com/microsoft/onnxruntime/issues/21386
             device: 'wasm',
             dtype: 'fp32',
             progress_callback,
         });
 
         return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
     }
 }
 
 async function load({ device }) {
     self.postMessage({
         status: 'loading',
         data: `Loading models (${device})...`
     });
 
     // Load the pipeline and save it for future use.
     const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
         // We also add a progress callback to the pipeline so that we can
         // track model loading.
         self.postMessage(x);
     }, device);
 
     if (device === 'webgpu') {
         self.postMessage({
             status: 'loading',
             data: 'Compiling shaders and warming up model...'
         });
 
         await transcriber(new Float32Array(16_000), {
             language: 'en',
         });
     }
 
     self.postMessage({ status: 'loaded' });
 }
 
 async function segment(processor, model, audio) {
     const inputs = await processor(audio);
     const { logits } = await model(inputs);
     const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
 
     // Attach labels
     for (const segment of segments) {
         segment.label = model.config.id2label[segment.id];
     }
 
     return segments;
 }
 
 async function run({ audio, language }) {
     const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
 
     const start = performance.now();
 
     // Run transcription and segmentation in parallel
     const [transcript, segments] = await Promise.all([
         transcriber(audio, {
             language,
             return_timestamps: 'word',
             chunk_length_s: 30,
         }),
         segment(segmentation_processor, segmentation_model, audio)
     ]);
     console.table(segments, ['start', 'end', 'id', 'label', 'confidence']);
 
     const end = performance.now();
 
     self.postMessage({ status: 'complete', result: { transcript, segments }, time: end - start });
 }
 
 // Listen for messages from the main thread
 self.addEventListener('message', async (e) => {
     const { type, data } = e.data;
 
     switch (type) {
         case 'load':
             load(data);
             break;
 
         case 'run':
             run(data);
             break;
     }
 });
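
For context, a minimal sketch of how the main thread might drive this worker, based only on the message types it handles above ('load' and 'run') and the status messages it posts back ('loading', 'loaded', 'complete'). The worker file path, bundler-style Worker construction, and the audio decoding step are assumptions, not part of this change:

// main.js (hypothetical) - drives the worker defined in src/worker.js
const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
    const { status } = e.data;
    if (status === 'loading') console.log(e.data.data);        // progress text from the worker
    if (status === 'loaded') console.log('Models ready');
    if (status === 'complete') console.log(e.data.result);     // { transcript, segments }
});

// Start model loading, falling back to WASM when WebGPU is unavailable
const device = navigator.gpu ? 'webgpu' : 'wasm';
worker.postMessage({ type: 'load', data: { device } });

// Later, once the input has been decoded to 16 kHz mono Float32Array samples:
// worker.postMessage({ type: 'run', data: { audio, language: 'en' } });

Because the transcript carries word-level timestamps and each diarization segment carries start, end, and label, a consumer would typically assign a speaker to each word by checking which segment its timestamp falls into.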