natbutter committed on
Commit: 3487feb
1 Parent(s): 6076ee9

Update whisper-speaker-diarization/src/worker.js

whisper-speaker-diarization/src/worker.js CHANGED
@@ -1,124 +1,124 @@

import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@xenova/transformers';

const PER_DEVICE_CONFIG = {
    webgpu: {
        dtype: {
            encoder_model: 'fp32',
            decoder_model_merged: 'q4',
        },
        device: 'webgpu',
    },
    wasm: {
        dtype: 'q8',
        device: 'wasm',
    },
};

/**
 * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
 */
class PipelineSingeton {
-    static asr_model_id = 'onnx-community/whisper-base_timestamped';
+    static asr_model_id = 'Xenova/whisper-large-v3';
    static asr_instance = null;

    static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
    static segmentation_instance = null;
    static segmentation_processor = null;

    static async getInstance(progress_callback = null, device = 'webgpu') {
        this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
            ...PER_DEVICE_CONFIG[device],
            progress_callback,
        });

        this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
            progress_callback,
        });
        this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
            // NOTE: WebGPU is not currently supported for this model
            // See https://github.com/microsoft/onnxruntime/issues/21386
            device: 'wasm',
            dtype: 'fp32',
            progress_callback,
        });

        return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
    }
}

async function load({ device }) {
    self.postMessage({
        status: 'loading',
        data: `Loading models (${device})...`
    });

    // Load the pipeline and save it for future use.
    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
        // We also add a progress callback to the pipeline so that we can
        // track model loading.
        self.postMessage(x);
    }, device);

    if (device === 'webgpu') {
        self.postMessage({
            status: 'loading',
            data: 'Compiling shaders and warming up model...'
        });

        await transcriber(new Float32Array(16_000), {
            language: 'en',
        });
    }

    self.postMessage({ status: 'loaded' });
}

async function segment(processor, model, audio) {
    const inputs = await processor(audio);
    const { logits } = await model(inputs);
    const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];

    // Attach labels
    for (const segment of segments) {
        segment.label = model.config.id2label[segment.id];
    }

    return segments;
}

async function run({ audio, language }) {
    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();

    const start = performance.now();

    // Run transcription and segmentation in parallel
    const [transcript, segments] = await Promise.all([
        transcriber(audio, {
            language,
            return_timestamps: 'word',
            chunk_length_s: 30,
        }),
        segment(segmentation_processor, segmentation_model, audio)
    ]);
    console.table(segments, ['start', 'end', 'id', 'label', 'confidence']);

    const end = performance.now();

    self.postMessage({ status: 'complete', result: { transcript, segments }, time: end - start });
}

// Listen for messages from the main thread
self.addEventListener('message', async (e) => {
    const { type, data } = e.data;

    switch (type) {
        case 'load':
            load(data);
            break;

        case 'run':
            run(data);
            break;
    }
});
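
For context, here is a minimal main-thread sketch (not part of this commit) of how the worker above could be driven, based on the `'load'`/`'run'` message types it handles and the `'loading'`/`'loaded'`/`'complete'` status messages it posts back. The worker URL, the placeholder silent audio buffer, and the `'en'` language choice are assumptions for illustration only; progress events emitted by the transformers.js loaders are ignored here.

```js
// main.js — hypothetical usage sketch for the diarization worker (assumed path './worker.js').
// The worker uses ES module imports, so it is created as a module worker.
const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });

// Placeholder input: 1 second of silence at 16 kHz (the sample rate the worker warms up with).
// In a real app this would be decoded audio from the user.
const audio = new Float32Array(16_000);

worker.addEventListener('message', (e) => {
    const { status } = e.data;
    if (status === 'loading') {
        console.log(e.data.data); // e.g. "Loading models (webgpu)..."
    } else if (status === 'loaded') {
        // Models are ready; request transcription + speaker segmentation.
        worker.postMessage({ type: 'run', data: { audio, language: 'en' } });
    } else if (status === 'complete') {
        const { transcript, segments } = e.data.result;
        console.log(`Done in ${e.data.time.toFixed(0)} ms`);
        console.log(transcript, segments);
    }
    // Other messages (per-file download progress from transformers.js) are ignored in this sketch.
});

// Kick off model loading; pass 'wasm' instead if WebGPU is unavailable in the browser.
worker.postMessage({ type: 'load', data: { device: 'webgpu' } });
```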