Rakot2223 aadnk commited on
Commit
cb1db42
0 Parent(s):

Duplicate from aadnk/faster-whisper-webui

Browse files

Co-authored-by: Kristian Stangeland <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ .vscode/
4
+ flagged/
5
+ *.py[cod]
6
+ *$py.class
LICENSE.md ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ ==============
3
+
4
+ _Version 2.0, January 2004_
5
+ _&lt;<http://www.apache.org/licenses/>&gt;_
6
+
7
+ ### Terms and Conditions for use, reproduction, and distribution
8
+
9
+ #### 1. Definitions
10
+
11
+ “License” shall mean the terms and conditions for use, reproduction, and
12
+ distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright
15
+ owner that is granting the License.
16
+
17
+ “Legal Entity” shall mean the union of the acting entity and all other entities
18
+ that control, are controlled by, or are under common control with that entity.
19
+ For the purposes of this definition, “control” means **(i)** the power, direct or
20
+ indirect, to cause the direction or management of such entity, whether by
21
+ contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or **(iii)** beneficial ownership of such entity.
23
+
24
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising
25
+ permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including
28
+ but not limited to software source code, documentation source, and configuration
29
+ files.
30
+
31
+ “Object” form shall mean any form resulting from mechanical transformation or
32
+ translation of a Source form, including but not limited to compiled object code,
33
+ generated documentation, and conversions to other media types.
34
+
35
+ “Work” shall mean the work of authorship, whether in Source or Object form, made
36
+ available under the License, as indicated by a copyright notice that is included
37
+ in or attached to the work (an example is provided in the Appendix below).
38
+
39
+ “Derivative Works” shall mean any work, whether in Source or Object form, that
40
+ is based on (or derived from) the Work and for which the editorial revisions,
41
+ annotations, elaborations, or other modifications represent, as a whole, an
42
+ original work of authorship. For the purposes of this License, Derivative Works
43
+ shall not include works that remain separable from, or merely link (or bind by
44
+ name) to the interfaces of, the Work and Derivative Works thereof.
45
+
46
+ “Contribution” shall mean any work of authorship, including the original version
47
+ of the Work and any modifications or additions to that Work or Derivative Works
48
+ thereof, that is intentionally submitted to Licensor for inclusion in the Work
49
+ by the copyright owner or by an individual or Legal Entity authorized to submit
50
+ on behalf of the copyright owner. For the purposes of this definition,
51
+ “submitted” means any form of electronic, verbal, or written communication sent
52
+ to the Licensor or its representatives, including but not limited to
53
+ communication on electronic mailing lists, source code control systems, and
54
+ issue tracking systems that are managed by, or on behalf of, the Licensor for
55
+ the purpose of discussing and improving the Work, but excluding communication
56
+ that is conspicuously marked or otherwise designated in writing by the copyright
57
+ owner as “Not a Contribution.”
58
+
59
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf
60
+ of whom a Contribution has been received by Licensor and subsequently
61
+ incorporated within the Work.
62
+
63
+ #### 2. Grant of Copyright License
64
+
65
+ Subject to the terms and conditions of this License, each Contributor hereby
66
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
67
+ irrevocable copyright license to reproduce, prepare Derivative Works of,
68
+ publicly display, publicly perform, sublicense, and distribute the Work and such
69
+ Derivative Works in Source or Object form.
70
+
71
+ #### 3. Grant of Patent License
72
+
73
+ Subject to the terms and conditions of this License, each Contributor hereby
74
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
75
+ irrevocable (except as stated in this section) patent license to make, have
76
+ made, use, offer to sell, sell, import, and otherwise transfer the Work, where
77
+ such license applies only to those patent claims licensable by such Contributor
78
+ that are necessarily infringed by their Contribution(s) alone or by combination
79
+ of their Contribution(s) with the Work to which such Contribution(s) was
80
+ submitted. If You institute patent litigation against any entity (including a
81
+ cross-claim or counterclaim in a lawsuit) alleging that the Work or a
82
+ Contribution incorporated within the Work constitutes direct or contributory
83
+ patent infringement, then any patent licenses granted to You under this License
84
+ for that Work shall terminate as of the date such litigation is filed.
85
+
86
+ #### 4. Redistribution
87
+
88
+ You may reproduce and distribute copies of the Work or Derivative Works thereof
89
+ in any medium, with or without modifications, and in Source or Object form,
90
+ provided that You meet the following conditions:
91
+
92
+ * **(a)** You must give any other recipients of the Work or Derivative Works a copy of
93
+ this License; and
94
+ * **(b)** You must cause any modified files to carry prominent notices stating that You
95
+ changed the files; and
96
+ * **(c)** You must retain, in the Source form of any Derivative Works that You distribute,
97
+ all copyright, patent, trademark, and attribution notices from the Source form
98
+ of the Work, excluding those notices that do not pertain to any part of the
99
+ Derivative Works; and
100
+ * **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any
101
+ Derivative Works that You distribute must include a readable copy of the
102
+ attribution notices contained within such NOTICE file, excluding those notices
103
+ that do not pertain to any part of the Derivative Works, in at least one of the
104
+ following places: within a NOTICE text file distributed as part of the
105
+ Derivative Works; within the Source form or documentation, if provided along
106
+ with the Derivative Works; or, within a display generated by the Derivative
107
+ Works, if and wherever such third-party notices normally appear. The contents of
108
+ the NOTICE file are for informational purposes only and do not modify the
109
+ License. You may add Your own attribution notices within Derivative Works that
110
+ You distribute, alongside or as an addendum to the NOTICE text from the Work,
111
+ provided that such additional attribution notices cannot be construed as
112
+ modifying the License.
113
+
114
+ You may add Your own copyright statement to Your modifications and may provide
115
+ additional or different license terms and conditions for use, reproduction, or
116
+ distribution of Your modifications, or for any such Derivative Works as a whole,
117
+ provided Your use, reproduction, and distribution of the Work otherwise complies
118
+ with the conditions stated in this License.
119
+
120
+ #### 5. Submission of Contributions
121
+
122
+ Unless You explicitly state otherwise, any Contribution intentionally submitted
123
+ for inclusion in the Work by You to the Licensor shall be under the terms and
124
+ conditions of this License, without any additional terms or conditions.
125
+ Notwithstanding the above, nothing herein shall supersede or modify the terms of
126
+ any separate license agreement you may have executed with Licensor regarding
127
+ such Contributions.
128
+
129
+ #### 6. Trademarks
130
+
131
+ This License does not grant permission to use the trade names, trademarks,
132
+ service marks, or product names of the Licensor, except as required for
133
+ reasonable and customary use in describing the origin of the Work and
134
+ reproducing the content of the NOTICE file.
135
+
136
+ #### 7. Disclaimer of Warranty
137
+
138
+ Unless required by applicable law or agreed to in writing, Licensor provides the
139
+ Work (and each Contributor provides its Contributions) on an “AS IS” BASIS,
140
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
141
+ including, without limitation, any warranties or conditions of TITLE,
142
+ NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
143
+ solely responsible for determining the appropriateness of using or
144
+ redistributing the Work and assume any risks associated with Your exercise of
145
+ permissions under this License.
146
+
147
+ #### 8. Limitation of Liability
148
+
149
+ In no event and under no legal theory, whether in tort (including negligence),
150
+ contract, or otherwise, unless required by applicable law (such as deliberate
151
+ and grossly negligent acts) or agreed to in writing, shall any Contributor be
152
+ liable to You for damages, including any direct, indirect, special, incidental,
153
+ or consequential damages of any character arising as a result of this License or
154
+ out of the use or inability to use the Work (including but not limited to
155
+ damages for loss of goodwill, work stoppage, computer failure or malfunction, or
156
+ any and all other commercial damages or losses), even if such Contributor has
157
+ been advised of the possibility of such damages.
158
+
159
+ #### 9. Accepting Warranty or Additional Liability
160
+
161
+ While redistributing the Work or Derivative Works thereof, You may choose to
162
+ offer, and charge a fee for, acceptance of support, warranty, indemnity, or
163
+ other liability obligations and/or rights consistent with this License. However,
164
+ in accepting such obligations, You may act only on Your own behalf and on Your
165
+ sole responsibility, not on behalf of any other Contributor, and only if You
166
+ agree to indemnify, defend, and hold each Contributor harmless for any liability
167
+ incurred by, or claims asserted against, such Contributor by reason of your
168
+ accepting any such warranty or additional liability.
169
+
170
+ _END OF TERMS AND CONDITIONS_
171
+
172
+ ### APPENDIX: How to apply the Apache License to your work
173
+
174
+ To apply the Apache License to your work, attach the following boilerplate
175
+ notice, with the fields enclosed by brackets `[]` replaced with your own
176
+ identifying information. (Don't include the brackets!) The text should be
177
+ enclosed in the appropriate comment syntax for the file format. We also
178
+ recommend that a file or class name and description of purpose be included on
179
+ the same “printed page” as the copyright notice for easier identification within
180
+ third-party archives.
181
+
182
+ Copyright [yyyy] [name of copyright owner]
183
+
184
+ Licensed under the Apache License, Version 2.0 (the "License");
185
+ you may not use this file except in compliance with the License.
186
+ You may obtain a copy of the License at
187
+
188
+ http://www.apache.org/licenses/LICENSE-2.0
189
+
190
+ Unless required by applicable law or agreed to in writing, software
191
+ distributed under the License is distributed on an "AS IS" BASIS,
192
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
193
+ See the License for the specific language governing permissions and
194
+ limitations under the License.
195
+
README.md ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Faster Whisper Webui
3
+ emoji: 🚀
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.23.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: aadnk/faster-whisper-webui
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ # Running Locally
17
+
18
+ To run this program locally, first install Python 3.9+ and Git. Then install Pytorch 10.1+ and all the other dependencies:
19
+ ```
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ You can find detailed instructions for how to install this on Windows 10/11 [here (PDF)](docs/windows/install_win10_win11.pdf).
24
+
25
+ Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
26
+ ```
27
+ python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
28
+ ```
29
+
30
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
31
+ ```
32
+ python cli.py \
33
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
34
+ [--vad_merge_window VAD_MERGE_WINDOW] \
35
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
36
+ [--vad_padding VAD_PADDING] \
37
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
38
+ [--vad_cpu_cores NUMBER_OF_CORES]
39
+ [--vad_parallel_devices COMMA_DELIMITED_DEVICES]
40
+ [--auto_parallel BOOLEAN]
41
+ ```
42
+ In addition, you may also use URL's in addition to file paths as input.
43
+ ```
44
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
45
+ ```
46
+
47
+ Rather than supplying arguments to `app.py` or `cli.py`, you can also use the configuration file [config.json5](config.json5). See that file for more information.
48
+ If you want to use a different configuration file, you can use the `WHISPER_WEBUI_CONFIG` environment variable to specify the path to another file.
49
+
50
+ ### Multiple Files
51
+
52
+ You can upload multiple files either through the "Upload files" option, or as a playlist on YouTube.
53
+ Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section.
54
+ When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
55
+
56
+ ## Whisper Implementation
57
+
58
+ You can choose between using `whisper` or `faster-whisper`. [Faster Whisper](https://github.com/guillaumekln/faster-whisper) as a drop-in replacement for the
59
+ default Whisper which achieves up to a 4x speedup and 2x reduction in memory usage.
60
+
61
+ You can install the requirements for a specific Whisper implementation in `requirements-fastWhisper.txt`
62
+ or `requirements-whisper.txt`:
63
+ ```
64
+ pip install -r requirements-fastWhisper.txt
65
+ ```
66
+ And then run the App or the CLI with the `--whisper_implementation fast-whisper` flag:
67
+ ```
68
+ python app.py --whisper_implementation fast-whisper --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
69
+ ```
70
+ You can also select the whisper implementation in `config.json5`:
71
+ ```json5
72
+ {
73
+ "whisper_implementation": "fast-whisper"
74
+ }
75
+ ```
76
+ ### GPU Acceleration
77
+
78
+ In order to use GPU acceleration with Faster Whisper, both CUDA 11.2 and cuDNN 8 must be installed. You may want to install it in a virtual environment like Anaconda.
79
+
80
+ ## Google Colab
81
+
82
+ You can also run this Web UI directly on [Google Colab](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing), if you haven't got a GPU powerful enough to run the larger models.
83
+
84
+ See the [colab documentation](docs/colab.md) for more information.
85
+
86
+ ## Parallel Execution
87
+
88
+ You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
89
+ device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
90
+ ```
91
+ python cli.py --model large --vad silero-vad --language Japanese \
92
+ --vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
93
+ ```
94
+
95
+ Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `period-vad` to avoid taking the hit
96
+ of running Silero-Vad, at a slight cost to accuracy.
97
+
98
+ This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
99
+ set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
100
+ The default value is 30 minutes.
101
+
102
+ ```
103
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
104
+ ```
105
+
106
+ To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
107
+ ```
108
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
109
+ ```
110
+
111
+ You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
112
+
113
+ ### Auto Parallel
114
+
115
+ You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
116
+ cores (up to 8):
117
+ ```
118
+ python app.py --input_audio_max_duration -1 --auto_parallel True
119
+ ```
120
+
121
+ # Docker
122
+
123
+ To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
124
+ Then either use the GitLab hosted container below, or check out this repository and build an image:
125
+ ```
126
+ sudo docker build -t whisper-webui:1 .
127
+ ```
128
+
129
+ You can then start the WebUI with GPU support like so:
130
+ ```
131
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
132
+ ```
133
+
134
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
135
+ ```
136
+ sudo docker run -d -p 7860:7860 whisper-webui:1
137
+ ```
138
+
139
+ # GitLab Docker Registry
140
+
141
+ This Docker container is also hosted on GitLab:
142
+
143
+ ```
144
+ sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
145
+ ```
146
+
147
+ ## Custom Arguments
148
+
149
+ You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel (replace administrator with your user):
150
+ ```
151
+ sudo docker run -d --gpus all -p 7860:7860 \
152
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
153
+ --mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
154
+ --restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
155
+ app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --auto_parallel True \
156
+ --default_vad silero-vad --default_model_name large
157
+ ```
158
+
159
+ You can also call `cli.py` the same way:
160
+ ```
161
+ sudo docker run --gpus all \
162
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
163
+ --mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
164
+ --mount type=bind,source=${PWD},target=/app/data \
165
+ registry.gitlab.com/aadnk/whisper-webui:latest \
166
+ cli.py --model large --auto_parallel True --vad silero-vad \
167
+ --output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
168
+ ```
169
+
170
+ ## Caching
171
+
172
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on the demand.
173
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
174
+ prepopulate the directory with the different Whisper models.
175
+ ```
176
+ sudo docker run -d --gpus=all -p 7860:7860 \
177
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
178
+ registry.gitlab.com/aadnk/whisper-webui:latest
179
+ ```
app-local.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1))
app-network.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions, and make it available on the network
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, server_name="0.0.0.0"))
app-shared.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, share=True))
app.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import math
3
+ from typing import Iterator, Union
4
+ import argparse
5
+
6
+ from io import StringIO
7
+ import os
8
+ import pathlib
9
+ import tempfile
10
+ import zipfile
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ from src.config import ApplicationConfig, VadInitialPromptMode
16
+ from src.hooks.progressListener import ProgressListener
17
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
18
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
19
+ from src.languages import get_language_names
20
+ from src.modelCache import ModelCache
21
+ from src.source import get_audio_source_collection
22
+ from src.vadParallel import ParallelContext, ParallelTranscription
23
+
24
+ # External programs
25
+ import ffmpeg
26
+
27
+ # UI
28
+ import gradio as gr
29
+
30
+ from src.download import ExceededMaximumDuration, download_url
31
+ from src.utils import slugify, write_srt, write_vtt
32
+ from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
33
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
34
+ from src.whisper.whisperFactory import create_whisper_container
35
+
36
+ # Configure more application defaults in config.json5
37
+
38
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
39
+ MAX_FILE_PREFIX_LENGTH = 17
40
+
41
+ # Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
42
+ MAX_AUTO_CPU_CORES = 8
43
+
44
+ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
45
+
46
+ class VadOptions:
47
+ def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
48
+ vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
49
+ self.vad = vad
50
+ self.vadMergeWindow = vadMergeWindow
51
+ self.vadMaxMergeSize = vadMaxMergeSize
52
+ self.vadPadding = vadPadding
53
+ self.vadPromptWindow = vadPromptWindow
54
+ self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
55
+ else VadInitialPromptMode.from_string(vadInitialPromptMode)
56
+
57
+ class WhisperTranscriber:
58
+ def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
59
+ vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
60
+ app_config: ApplicationConfig = None):
61
+ self.model_cache = ModelCache()
62
+ self.parallel_device_list = None
63
+ self.gpu_parallel_context = None
64
+ self.cpu_parallel_context = None
65
+ self.vad_process_timeout = vad_process_timeout
66
+ self.vad_cpu_cores = vad_cpu_cores
67
+
68
+ self.vad_model = None
69
+ self.inputAudioMaxDuration = input_audio_max_duration
70
+ self.deleteUploadedFiles = delete_uploaded_files
71
+ self.output_dir = output_dir
72
+
73
+ self.app_config = app_config
74
+
75
+ def set_parallel_devices(self, vad_parallel_devices: str):
76
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
77
+
78
+ def set_auto_parallel(self, auto_parallel: bool):
79
+ if auto_parallel:
80
+ if torch.cuda.is_available():
81
+ self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
82
+
83
+ self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
84
+ print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
85
+
86
+ # Entry function for the simple tab
87
+ def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
88
+ progress=gr.Progress()):
89
+
90
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, self.app_config.vad_initial_prompt_mode)
91
+
92
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions, progress=progress)
93
+
94
+ # Entry function for the full tab
95
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
96
+ vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
97
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
98
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
99
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
100
+ progress=gr.Progress()):
101
+
102
+ # Handle temperature_increment_on_fallback
103
+ if temperature_increment_on_fallback is not None:
104
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
105
+ else:
106
+ temperature = [temperature]
107
+
108
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
109
+
110
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
111
+ initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
112
+ condition_on_previous_text=condition_on_previous_text, fp16=fp16,
113
+ compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
114
+ progress=progress)
115
+
116
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
117
+ vadOptions: VadOptions, progress: gr.Progress = None, **decodeOptions: dict):
118
+ try:
119
+ sources = self.__get_source(urlData, multipleFiles, microphoneData)
120
+
121
+ try:
122
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
123
+ selectedModel = modelName if modelName is not None else "base"
124
+
125
+ model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
126
+ model_name=selectedModel, compute_type=self.app_config.compute_type,
127
+ cache=self.model_cache, models=self.app_config.models)
128
+
129
+ # Result
130
+ download = []
131
+ zip_file_lookup = {}
132
+ text = ""
133
+ vtt = ""
134
+
135
+ # Write result
136
+ downloadDirectory = tempfile.mkdtemp()
137
+ source_index = 0
138
+
139
+ outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
140
+
141
+ # Progress
142
+ total_duration = sum([source.get_audio_duration() for source in sources])
143
+ current_progress = 0
144
+
145
+ # A listener that will report progress to Gradio
146
+ root_progress_listener = self._create_progress_listener(progress)
147
+
148
+ # Execute whisper
149
+ for source in sources:
150
+ source_prefix = ""
151
+ source_audio_duration = source.get_audio_duration()
152
+
153
+ if (len(sources) > 1):
154
+ # Prefix (minimum 2 digits)
155
+ source_index += 1
156
+ source_prefix = str(source_index).zfill(2) + "_"
157
+ print("Transcribing ", source.source_path)
158
+
159
+ scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
160
+ base_task_total=total_duration,
161
+ sub_task_start=current_progress,
162
+ sub_task_total=source_audio_duration)
163
+
164
+ # Transcribe
165
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
166
+ filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
167
+
168
+ # Update progress
169
+ current_progress += source_audio_duration
170
+
171
+ source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
172
+
173
+ if len(sources) > 1:
174
+ # Add new line separators
175
+ if (len(source_text) > 0):
176
+ source_text += os.linesep + os.linesep
177
+ if (len(source_vtt) > 0):
178
+ source_vtt += os.linesep + os.linesep
179
+
180
+ # Append file name to source text too
181
+ source_text = source.get_full_name() + ":" + os.linesep + source_text
182
+ source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
183
+
184
+ # Add to result
185
+ download.extend(source_download)
186
+ text += source_text
187
+ vtt += source_vtt
188
+
189
+ if (len(sources) > 1):
190
+ # Zip files support at least 260 characters, but we'll play it safe and use 200
191
+ zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
192
+
193
+ # File names in ZIP file can be longer
194
+ for source_download_file in source_download:
195
+ # Get file postfix (after last -)
196
+ filePostfix = os.path.basename(source_download_file).split("-")[-1]
197
+ zip_file_name = zipFilePrefix + "-" + filePostfix
198
+ zip_file_lookup[source_download_file] = zip_file_name
199
+
200
+ # Create zip file from all sources
201
+ if len(sources) > 1:
202
+ downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
203
+
204
+ with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
205
+ for download_file in download:
206
+ # Get file name from lookup
207
+ zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
208
+ zip.write(download_file, arcname=zip_file_name)
209
+
210
+ download.insert(0, downloadAllPath)
211
+
212
+ return download, text, vtt
213
+
214
+ finally:
215
+ # Cleanup source
216
+ if self.deleteUploadedFiles:
217
+ for source in sources:
218
+ print("Deleting source file " + source.source_path)
219
+
220
+ try:
221
+ os.remove(source.source_path)
222
+ except Exception as e:
223
+ # Ignore error - it's just a cleanup
224
+ print("Error deleting source file " + source.source_path + ": " + str(e))
225
+
226
+ except ExceededMaximumDuration as e:
227
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
228
+
229
+ def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
230
+ vadOptions: VadOptions = VadOptions(),
231
+ progressListener: ProgressListener = None, **decodeOptions: dict):
232
+
233
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
234
+
235
+ if progressListener is None:
236
+ # Default progress listener
237
+ progressListener = ProgressListener()
238
+
239
+ if ('task' in decodeOptions):
240
+ task = decodeOptions.pop('task')
241
+
242
+ # Callable for processing an audio file
243
+ whisperCallable = model.create_callback(language, task, initial_prompt, initial_prompt_mode=vadOptions.vadInitialPromptMode, **decodeOptions)
244
+
245
+ # The results
246
+ if (vadOptions.vad == 'silero-vad'):
247
+ # Silero VAD where non-speech gaps are transcribed
248
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
249
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
250
+ elif (vadOptions.vad == 'silero-vad-skip-gaps'):
251
+ # Silero VAD where non-speech gaps are simply ignored
252
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
253
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
254
+ elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
255
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
256
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
257
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
258
+ elif (vadOptions.vad == 'periodic-vad'):
259
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
260
+ # it may create a break in the middle of a sentence, causing some artifacts.
261
+ periodic_vad = VadPeriodicTranscription()
262
+ period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
263
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
264
+
265
+ else:
266
+ if (self._has_parallel_devices()):
267
+ # Use a simple period transcription instead, as we need to use the parallel context
268
+ periodic_vad = VadPeriodicTranscription()
269
+ period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
270
+
271
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
272
+ else:
273
+ # Default VAD
274
+ result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
275
+
276
+ return result
277
+
278
+ def _create_progress_listener(self, progress: gr.Progress):
279
+ if (progress is None):
280
+ # Dummy progress listener
281
+ return ProgressListener()
282
+
283
+ class ForwardingProgressListener(ProgressListener):
284
+ def __init__(self, progress: gr.Progress):
285
+ self.progress = progress
286
+
287
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
288
+ # From 0 to 1
289
+ self.progress(current / total)
290
+
291
+ def on_finished(self):
292
+ self.progress(1)
293
+
294
+ return ForwardingProgressListener(progress)
295
+
296
+ def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
297
+ progressListener: ProgressListener = None):
298
+ if (not self._has_parallel_devices()):
299
+ # No parallel devices, so just run the VAD and Whisper in sequence
300
+ return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
301
+
302
+ gpu_devices = self.parallel_device_list
303
+
304
+ if (gpu_devices is None or len(gpu_devices) == 0):
305
+ # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
306
+ gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
307
+
308
+ # Create parallel context if needed
309
+ if (self.gpu_parallel_context is None):
310
+ # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
311
+ self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
312
+ # We also need a CPU context for the VAD
313
+ if (self.cpu_parallel_context is None):
314
+ self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
315
+
316
+ parallel_vad = ParallelTranscription()
317
+ return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
318
+ config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
319
+ cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
320
+ progress_listener=progressListener)
321
+
322
+ def _has_parallel_devices(self):
323
+ return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
324
+
325
+ def _concat_prompt(self, prompt1, prompt2):
326
+ if (prompt1 is None):
327
+ return prompt2
328
+ elif (prompt2 is None):
329
+ return prompt1
330
+ else:
331
+ return prompt1 + " " + prompt2
332
+
333
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
334
+ # Use Silero VAD
335
+ if (self.vad_model is None):
336
+ self.vad_model = VadSileroTranscription()
337
+
338
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
339
+ max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
340
+ segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
341
+ max_prompt_window=vadOptions.vadPromptWindow)
342
+
343
+ return config
344
+
345
+ def write_result(self, result: dict, source_name: str, output_dir: str):
346
+ if not os.path.exists(output_dir):
347
+ os.makedirs(output_dir)
348
+
349
+ text = result["text"]
350
+ language = result["language"]
351
+ languageMaxLineWidth = self.__get_max_line_width(language)
352
+
353
+ print("Max line width " + str(languageMaxLineWidth))
354
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
355
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
356
+
357
+ output_files = []
358
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
359
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
360
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
361
+
362
+ return output_files, text, vtt
363
+
364
+ def clear_cache(self):
365
+ self.model_cache.clear()
366
+ self.vad_model = None
367
+
368
+ def __get_source(self, urlData, multipleFiles, microphoneData):
369
+ return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
370
+
371
+ def __get_max_line_width(self, language: str) -> int:
372
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
373
+ # Chinese characters and kana are wider, so limit line length to 40 characters
374
+ return 40
375
+ else:
376
+ # TODO: Add more languages
377
+ # 80 latin characters should fit on a 1080p/720p screen
378
+ return 80
379
+
380
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
381
+ segmentStream = StringIO()
382
+
383
+ if format == 'vtt':
384
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
385
+ elif format == 'srt':
386
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
387
+ else:
388
+ raise Exception("Unknown format " + format)
389
+
390
+ segmentStream.seek(0)
391
+ return segmentStream.read()
392
+
393
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
394
+ # Write the text to a file
395
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
396
+ file.write(text)
397
+
398
+ return file.name
399
+
400
+ def close(self):
401
+ print("Closing parallel contexts")
402
+ self.clear_cache()
403
+
404
+ if (self.gpu_parallel_context is not None):
405
+ self.gpu_parallel_context.close()
406
+ if (self.cpu_parallel_context is not None):
407
+ self.cpu_parallel_context.close()
408
+
409
+
410
+ def create_ui(app_config: ApplicationConfig):
411
+ ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
412
+ app_config.delete_uploaded_files, app_config.output_dir, app_config)
413
+
414
+ # Specify a list of devices to use for parallel processing
415
+ ui.set_parallel_devices(app_config.vad_parallel_devices)
416
+ ui.set_auto_parallel(app_config.auto_parallel)
417
+
418
+ is_whisper = False
419
+
420
+ if app_config.whisper_implementation == "whisper":
421
+ implementation_name = "Whisper"
422
+ is_whisper = True
423
+ elif app_config.whisper_implementation in ["faster-whisper", "faster_whisper"]:
424
+ implementation_name = "Faster Whisper"
425
+ else:
426
+ # Try to convert from camel-case to title-case
427
+ implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
428
+
429
+ ui_description = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
430
+ ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
431
+ ui_description += " as well as speech translation and language identification. "
432
+
433
+ ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
434
+
435
+ # Recommend faster-whisper
436
+ if is_whisper:
437
+ ui_description += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
438
+
439
+ if app_config.input_audio_max_duration > 0:
440
+ ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
441
+
442
+ ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
443
+
444
+ whisper_models = app_config.get_model_names()
445
+
446
+ simple_inputs = lambda : [
447
+ gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
448
+ gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
449
+ gr.Text(label="URL (YouTube, etc.)"),
450
+ gr.File(label="Upload Files", file_count="multiple"),
451
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
452
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
453
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
454
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
455
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
456
+ gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
457
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
458
+ ]
459
+
460
+ simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple, description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
461
+ gr.File(label="Download"),
462
+ gr.Text(label="Transcription"),
463
+ gr.Text(label="Segments")
464
+ ])
465
+
466
+ full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
467
+
468
+ full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
469
+ *simple_inputs(),
470
+ gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
471
+ gr.TextArea(label="Initial Prompt"),
472
+ gr.Number(label="Temperature", value=app_config.temperature),
473
+ gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
474
+ gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
475
+ gr.Number(label="Patience - Zero temperature", value=app_config.patience),
476
+ gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
477
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
478
+ gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
479
+ gr.Checkbox(label="FP16", value=app_config.fp16),
480
+ gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
481
+ gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
482
+ gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
483
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
484
+ ], outputs=[
485
+ gr.File(label="Download"),
486
+ gr.Text(label="Transcription"),
487
+ gr.Text(label="Segments")
488
+ ])
489
+
490
+ demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
491
+
492
+ # Queue up the demo
493
+ if app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0:
494
+ demo.queue(concurrency_count=app_config.queue_concurrency_count)
495
+
496
+ demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
497
+
498
+ # Clean up
499
+ ui.close()
500
+
501
+ if __name__ == '__main__':
502
+ default_app_config = ApplicationConfig.create_default()
503
+ whisper_models = default_app_config.get_model_names()
504
+
505
+ # Environment variable overrides
506
+ default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
507
+
508
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
509
+ parser.add_argument("--input_audio_max_duration", type=int, default=default_app_config.input_audio_max_duration, \
510
+ help="Maximum audio file length in seconds, or -1 for no limit.") # 600
511
+ parser.add_argument("--share", type=bool, default=default_app_config.share, \
512
+ help="True to share the app on HuggingFace.") # False
513
+ parser.add_argument("--server_name", type=str, default=default_app_config.server_name, \
514
+ help="The host or IP to bind to. If None, bind to localhost.") # None
515
+ parser.add_argument("--server_port", type=int, default=default_app_config.server_port, \
516
+ help="The port to bind to.") # 7860
517
+ parser.add_argument("--queue_concurrency_count", type=int, default=default_app_config.queue_concurrency_count, \
518
+ help="The number of concurrent requests to process.") # 1
519
+ parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=default_app_config.default_model_name, \
520
+ help="The default model name.") # medium
521
+ parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
522
+ help="The default VAD.") # silero-vad
523
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
524
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
525
+ parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
526
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
527
+ parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
528
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
529
+ parser.add_argument("--vad_process_timeout", type=float, default=default_app_config.vad_process_timeout, \
530
+ help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
531
+ parser.add_argument("--auto_parallel", type=bool, default=default_app_config.auto_parallel, \
532
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
533
+ parser.add_argument("--output_dir", "-o", type=str, default=default_app_config.output_dir, \
534
+ help="directory to save the outputs")
535
+ parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
536
+ help="the Whisper implementation to use")
537
+ parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
538
+ help="the compute type to use for inference")
539
+
540
+ args = parser.parse_args().__dict__
541
+
542
+ updated_config = default_app_config.update(**args)
543
+
544
+ create_ui(app_config=updated_config)
cli.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import torch
9
+ from app import VadOptions, WhisperTranscriber
10
+ from src.config import ApplicationConfig, VadInitialPromptMode
11
+ from src.download import download_url
12
+ from src.languages import get_language_names
13
+
14
+ from src.utils import optional_float, optional_int, str2bool
15
+ from src.whisper.whisperFactory import create_whisper_container
16
+
17
+ def cli():
18
+ app_config = ApplicationConfig.create_default()
19
+ whisper_models = app_config.get_model_names()
20
+
21
+ # For the CLI, we fallback to saving the output to the current directory
22
+ output_dir = app_config.output_dir if app_config.output_dir is not None else "."
23
+
24
+ # Environment variable overrides
25
+ default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", app_config.whisper_implementation)
26
+
27
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28
+ parser.add_argument("audio", nargs="+", type=str, \
29
+ help="audio file(s) to transcribe")
30
+ parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
31
+ help="name of the Whisper model to use") # medium
32
+ parser.add_argument("--model_dir", type=str, default=app_config.model_dir, \
33
+ help="the path to save model files; uses ~/.cache/whisper by default")
34
+ parser.add_argument("--device", default=app_config.device, \
35
+ help="device to use for PyTorch inference")
36
+ parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
37
+ help="directory to save the outputs")
38
+ parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
39
+ help="whether to print out the progress and debug messages")
40
+ parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
41
+ help="the Whisper implementation to use")
42
+
43
+ parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
44
+ help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
45
+ parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_language_names()), \
46
+ help="language spoken in the audio, specify None to perform language detection")
47
+
48
+ parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
+ help="The voice activity detection algorithm to use") # silero-vad
50
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
51
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
52
+ parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
53
+ help="The window size (in seconds) to merge voice segments")
54
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
55
+ help="The maximum size (in seconds) of a voice segment")
56
+ parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding, \
57
+ help="The padding (in seconds) to add to each voice segment")
58
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window, \
59
+ help="The window size of the prompt to pass to Whisper")
60
+ parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
61
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
62
+ parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
63
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
64
+ parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
65
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
66
+
67
+ parser.add_argument("--temperature", type=float, default=app_config.temperature, \
68
+ help="temperature to use for sampling")
69
+ parser.add_argument("--best_of", type=optional_int, default=app_config.best_of, \
70
+ help="number of candidates when sampling with non-zero temperature")
71
+ parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size, \
72
+ help="number of beams in beam search, only applicable when temperature is zero")
73
+ parser.add_argument("--patience", type=float, default=app_config.patience, \
74
+ help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
75
+ parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty, \
76
+ help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
77
+
78
+ parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens, \
79
+ help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
80
+ parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt, \
81
+ help="optional text to provide as a prompt for the first window.")
82
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text, \
83
+ help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
84
+ parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
85
+ help="whether to perform inference in fp16; True by default")
86
+ parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
87
+ help="the compute type to use for inference")
88
+
89
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
90
+ help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
91
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold, \
92
+ help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
93
+ parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold, \
94
+ help="if the average log probability is lower than this value, treat the decoding as failed")
95
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
96
+ help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
97
+
98
+ args = parser.parse_args().__dict__
99
+ model_name: str = args.pop("model")
100
+ model_dir: str = args.pop("model_dir")
101
+ output_dir: str = args.pop("output_dir")
102
+ device: str = args.pop("device")
103
+ os.makedirs(output_dir, exist_ok=True)
104
+
105
+ whisper_implementation = args.pop("whisper_implementation")
106
+ print(f"Using {whisper_implementation} for Whisper")
107
+
108
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
109
+ warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
110
+ args["language"] = "en"
111
+
112
+ temperature = args.pop("temperature")
113
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
114
+ if temperature_increment_on_fallback is not None:
115
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
116
+ else:
117
+ temperature = [temperature]
118
+
119
+ vad = args.pop("vad")
120
+ vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
121
+ vad_merge_window = args.pop("vad_merge_window")
122
+ vad_max_merge_size = args.pop("vad_max_merge_size")
123
+ vad_padding = args.pop("vad_padding")
124
+ vad_prompt_window = args.pop("vad_prompt_window")
125
+ vad_cpu_cores = args.pop("vad_cpu_cores")
126
+ auto_parallel = args.pop("auto_parallel")
127
+
128
+ compute_type = args.pop("compute_type")
129
+
130
+ transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
131
+ transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
132
+ transcriber.set_auto_parallel(auto_parallel)
133
+
134
+ model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
135
+ device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
136
+
137
+ if (transcriber._has_parallel_devices()):
138
+ print("Using parallel devices:", transcriber.parallel_device_list)
139
+
140
+ for audio_path in args.pop("audio"):
141
+ sources = []
142
+
143
+ # Detect URL and download the audio
144
+ if (uri_validator(audio_path)):
145
+ # Download from YouTube/URL directly
146
+ for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
147
+ source_name = os.path.basename(source_path)
148
+ sources.append({ "path": source_path, "name": source_name })
149
+ else:
150
+ sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
151
+
152
+ for source in sources:
153
+ source_path = source["path"]
154
+ source_name = source["name"]
155
+
156
+ vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
157
+ VadInitialPromptMode.from_string(vad_initial_prompt_mode))
158
+
159
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
160
+
161
+ transcriber.write_result(result, source_name, output_dir)
162
+
163
+ transcriber.close()
164
+
165
+ def uri_validator(x):
166
+ try:
167
+ result = urlparse(x)
168
+ return all([result.scheme, result.netloc])
169
+ except:
170
+ return False
171
+
172
+ if __name__ == '__main__':
173
+ cli()
config.json5 ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models": [
3
+ // Configuration for the built-in models. You can remove any of these
4
+ // if you don't want to use the default models.
5
+ {
6
+ "name": "tiny",
7
+ "url": "tiny"
8
+ },
9
+ {
10
+ "name": "base",
11
+ "url": "base"
12
+ },
13
+ {
14
+ "name": "small",
15
+ "url": "small"
16
+ },
17
+ {
18
+ "name": "medium",
19
+ "url": "medium"
20
+ },
21
+ {
22
+ "name": "large",
23
+ "url": "large"
24
+ },
25
+ {
26
+ "name": "large-v2",
27
+ "url": "large-v2"
28
+ },
29
+ // Uncomment to add custom Japanese models
30
+ //{
31
+ // "name": "whisper-large-v2-mix-jp",
32
+ // "url": "vumichien/whisper-large-v2-mix-jp",
33
+ // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
34
+ // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
35
+ // "type": "huggingface",
36
+ //},
37
+ //{
38
+ // "name": "local-model",
39
+ // "url": "path/to/local/model",
40
+ //},
41
+ //{
42
+ // "name": "remote-model",
43
+ // "url": "https://example.com/path/to/model",
44
+ //}
45
+ ],
46
+ // Configuration options that will be used if they are not specified in the command line arguments.
47
+
48
+ // * WEBUI options *
49
+
50
+ // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
51
+ "input_audio_max_duration": 600,
52
+ // True to share the app on HuggingFace.
53
+ "share": false,
54
+ // The host or IP to bind to. If None, bind to localhost.
55
+ "server_name": null,
56
+ // The port to bind to.
57
+ "server_port": 7860,
58
+ // The number of workers to use for the web server. Use -1 to disable queueing.
59
+ "queue_concurrency_count": 1,
60
+ // Whether or not to automatically delete all uploaded files, to save disk space
61
+ "delete_uploaded_files": true,
62
+
63
+ // * General options *
64
+
65
+ // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
66
+ // Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
67
+ // or whisper (requirements.txt)
68
+ "whisper_implementation": "faster-whisper",
69
+
70
+ // The default model name.
71
+ "default_model_name": "medium",
72
+ // The default VAD.
73
+ "default_vad": "silero-vad",
74
+ // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
75
+ "vad_parallel_devices": "",
76
+ // The number of CPU cores to use for VAD pre-processing.
77
+ "vad_cpu_cores": 1,
78
+ // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
79
+ "vad_process_timeout": 1800,
80
+ // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
81
+ "auto_parallel": false,
82
+ // Directory to save the outputs (CLI will use the current directory if not specified)
83
+ "output_dir": null,
84
+ // The path to save model files; uses ~/.cache/whisper by default
85
+ "model_dir": null,
86
+ // Device to use for PyTorch inference, or Null to use the default device
87
+ "device": null,
88
+ // Whether to print out the progress and debug messages
89
+ "verbose": true,
90
+ // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
91
+ "task": "transcribe",
92
+ // Language spoken in the audio, specify None to perform language detection
93
+ "language": null,
94
+ // The window size (in seconds) to merge voice segments
95
+ "vad_merge_window": 5,
96
+ // The maximum size (in seconds) of a voice segment
97
+ "vad_max_merge_size": 30,
98
+ // The padding (in seconds) to add to each voice segment
99
+ "vad_padding": 1,
100
+ // Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
101
+ "vad_initial_prompt_mode": "prepend_first_segment",
102
+ // The window size of the prompt to pass to Whisper
103
+ "vad_prompt_window": 3,
104
+ // Temperature to use for sampling
105
+ "temperature": 0,
106
+ // Number of candidates when sampling with non-zero temperature
107
+ "best_of": 5,
108
+ // Number of beams in beam search, only applicable when temperature is zero
109
+ "beam_size": 5,
110
+ // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
111
+ "patience": 1,
112
+ // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
113
+ "length_penalty": null,
114
+ // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
115
+ "suppress_tokens": "-1",
116
+ // Optional text to provide as a prompt for the first window
117
+ "initial_prompt": null,
118
+ // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
119
+ "condition_on_previous_text": true,
120
+ // Whether to perform inference in fp16; True by default
121
+ "fp16": true,
122
+ // The compute type used by faster-whisper. Can be "int8". "int16" or "float16".
123
+ "compute_type": "auto",
124
+ // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
125
+ "temperature_increment_on_fallback": 0.2,
126
+ // If the gzip compression ratio is higher than this value, treat the decoding as failed
127
+ "compression_ratio_threshold": 2.4,
128
+ // If the average log probability is lower than this value, treat the decoding as failed
129
+ "logprob_threshold": -1.0,
130
+ // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
131
+ "no_speech_threshold": 0.6
132
+ }
dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docker build -t whisper-webui --build-arg WHISPER_IMPLEMENTATION=whisper .
2
+
3
+ FROM huggingface/transformers-pytorch-gpu
4
+ EXPOSE 7860
5
+
6
+ ARG WHISPER_IMPLEMENTATION=whisper
7
+ ENV WHISPER_IMPLEMENTATION=${WHISPER_IMPLEMENTATION}
8
+
9
+ ADD . /opt/whisper-webui/
10
+
11
+ # Latest version of transformers-pytorch-gpu seems to lack tk.
12
+ # Further, pip install fails, so we must upgrade pip first.
13
+ RUN apt-get -y install python3-tk
14
+ RUN python3 -m pip install --upgrade pip
15
+
16
+ RUN if [ "${WHISPER_IMPLEMENTATION}" = "whisper" ]; then \
17
+ python3 -m pip install -r /opt/whisper-webui/requirements-whisper.txt; \
18
+ else \
19
+ python3 -m pip install -r /opt/whisper-webui/requirements-fasterWhisper.txt; \
20
+ fi
21
+
22
+ # Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
23
+ # You can also bind this directory in the container to somewhere on the host.
24
+
25
+ # To be able to see logs in real time
26
+ ENV PYTHONUNBUFFERED=1
27
+
28
+ WORKDIR /opt/whisper-webui/
29
+ ENTRYPOINT ["python3"]
30
+ CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
docs/colab.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Whisper on Google Colab
2
+
3
+ If you don't have a decent GPU or any experience in running command-line applications, you might want to try this Google Colab instead:
4
+
5
+ * [Google Colab - Whisper WebUI GPU](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing)
6
+ * [Screenshots](https://imgur.com/a/ZfY6uBO)
7
+
8
+ The runtime (Runtime -> Change runtime type -> Hardware accelerator) should already be set top GPU. But if not, change it to GPU.
9
+
10
+ Then, sign in to Google if you haven't already. Next, click on "Connect" at the top right.
11
+
12
+ Under "Checking out WebUI from Git", click on the [play icon](https://imgur.com/a/81gOLyD) that appears in "[ ]" at the left. If you get a warning, click "Run anyway".
13
+
14
+ After this step has completed, it should be get a green check mark. Then move on to the next section under "Installing dependencies", and click in "[ ]" again. This might take approximately 30 seconds.
15
+
16
+ Once this has completed, scroll down to the "Run WebUI" section, and click on "[ ]". This will launch the WebUI in a shared link (expires in 72 hours). To open the UI, click on the link next to "Running on public URL", which will be something like https://12xxx.gradio.app/
17
+
18
+ The audio length in this version is not restricted, and it will run much faster as it is backed by a GPU. You can also run it using the "Large" model. Also note that it might take some time to start the model the first time, as it may need to download a 2.8 GB file on Google's servers.
19
+
20
+ Once you're done, you can close the WebUI session by clicking the animated close button under "Run WebUI". You can also do this if you encounter any errors and need to restart the UI. You should also go to "Manage Sessions" and terminate the session, otherwise you may end up using all your free compute credits.
docs/options.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Options
2
+ To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
+ supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
+ in the file selector to select any file type, including video files) or use the microphone.
5
+
6
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option, especially if you are using the `large-v1` model. Note that `large-v2` is a lot more forgiving, but you may still want to use a VAD with a slightly higher "VAD - Max Merge Size (s)" (60 seconds or more).
7
+
8
+ ## Model
9
+ Select the model that Whisper will use to transcribe the audio:
10
+
11
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
12
+ |-----------|------------|--------------------|--------------------|---------------|----------------|
13
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
14
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
15
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
16
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
17
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
18
+ | large-v2 | 1550 M | N/A | large | ~10 GB | 1x |
19
+
20
+ ## Language
21
+
22
+ Select the language, or leave it empty for Whisper to automatically detect it.
23
+
24
+ Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
25
+ language. For instance, if the audio is in English but you select Japaneese, the model may translate the audio to Japanese.
26
+
27
+ ## Inputs
28
+ The options "URL (YouTube, etc.)", "Upload Files" or "Micriphone Input" allows you to send an audio input to the model.
29
+
30
+ ### Multiple Files
31
+ Note that the UI will only process either the given URL or the upload files (including microphone) - not both.
32
+
33
+ But you can upload multiple files either through the "Upload files" option, or as a playlist on YouTube. Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section. When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
34
+
35
+ ## Task
36
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
37
+
38
+ ## Vad
39
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
40
+ loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
41
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
42
+
43
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
44
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
45
+
46
+ * none
47
+ * Run whisper on the entire audio input
48
+ * silero-vad
49
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Whisper is also run
50
+ on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
51
+ on the non-speech section.
52
+ * silero-vad-expand-into-gaps
53
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Each spech section will be expanded
54
+ such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
55
+ 00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
56
+ * silero-vad-skip-gaps
57
+ * As above, but sections that doesn't contain speech according to Silero will be skipped. This will be slightly faster, but
58
+ may cause dialogue to be skipped.
59
+ * periodic-vad
60
+ * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
61
+ a sentence or word in two.
62
+
63
+ ## VAD - Merge Window
64
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
65
+
66
+ ## VAD - Max Merge Size (s)
67
+ Disables merging of adjacent speech sections if they are this number of seconds long.
68
+
69
+ ## VAD - Padding (s)
70
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
71
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
72
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
73
+ to each transcribed line. The default value is 1 second.
74
+
75
+ ## VAD - Prompt Window (s)
76
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
77
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
78
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
79
+
80
+ Note that detected lines in gaps between speech sections will not be included in the prompt
81
+ (if silero-vad or silero-vad-expand-into-gaps) is used.
82
+
83
+ # Command Line Options
84
+
85
+ Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
86
+ CPU/GPU cores, the default model name/VAD and so on. Consult the README in the root folder for more information.
87
+
88
+ # Additional Options
89
+
90
+ In addition to the above, there's also a "Full" options interface that allows you to set all the options available in the Whisper
91
+ model. The options are as follows:
92
+
93
+ ## Initial Prompt
94
+ Optional text to provide as a prompt for the first 30 seconds window. Whisper will attempt to use this as a starting point for the transcription, but you can
95
+ also get creative and specify a style or format for the output of the transcription.
96
+
97
+ For instance, if you use the prompt "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they", Whisper will
98
+ be biased to output lower capital letters and no punctuation, and may also be biased to output the words in the prompt more often.
99
+
100
+ ## Temperature
101
+ The temperature to use when sampling. Default is 0 (zero). A higher temperature will result in more random output, while a lower temperature will be more deterministic.
102
+
103
+ ## Best Of - Non-zero temperature
104
+ The number of candidates to sample from when sampling with non-zero temperature. Default is 5.
105
+
106
+ ## Beam Size - Zero temperature
107
+ The number of beams to use in beam search when sampling with zero temperature. Default is 5.
108
+
109
+ ## Patience - Zero temperature
110
+ The patience value to use in beam search when sampling with zero temperature. As in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search.
111
+
112
+ ## Length Penalty - Any temperature
113
+ The token length penalty coefficient (alpha) to use when sampling with any temperature. As in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.
114
+
115
+ ## Suppress Tokens - Comma-separated list of token IDs
116
+ A comma-separated list of token IDs to suppress during sampling. The default value of "-1" will suppress most special characters except common punctuations.
117
+
118
+ ## Condition on previous text
119
+ If True, provide the previous output of the model as a prompt for the next window. Disabling this may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop.
120
+
121
+ ## FP16
122
+ Whether to perform inference in fp16. True by default.
123
+
124
+ ## Temperature increment on fallback
125
+ The temperature to increase when falling back when the decoding fails to meet either of the thresholds below. Default is 0.2.
126
+
127
+ ## Compression ratio threshold
128
+ If the gzip compression ratio is higher than this value, treat the decoding as failed. Default is 2.4.
129
+
130
+ ## Logprob threshold
131
+ If the average log probability is lower than this value, treat the decoding as failed. Default is -1.0.
132
+
133
+ ## No speech threshold
134
+ If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
docs/windows/install_win10_win11.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b9f4ed547d6534411c17da1ea56707d2ec6e812611b1cbd3098756d5cbb8084
3
+ size 3378789
requirements-fasterWhisper.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ctranslate2
2
+ faster-whisper
3
+ ffmpeg-python==0.2.0
4
+ gradio==3.23.0
5
+ yt-dlp
6
+ json5
7
+ torch
8
+ torchaudio
9
+ more_itertools
requirements-whisper.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ git+https://github.com/openai/whisper.git
3
+ transformers
4
+ ffmpeg-python==0.2.0
5
+ gradio==3.23.0
6
+ yt-dlp
7
+ torchaudio
8
+ altair
9
+ json5
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ctranslate2
2
+ faster-whisper
3
+ ffmpeg-python==0.2.0
4
+ gradio==3.23.0
5
+ yt-dlp
6
+ json5
7
+ torch
8
+ torchaudio
9
+ more_itertools
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import urllib
3
+
4
+ import os
5
+ from typing import List
6
+ from urllib.parse import urlparse
7
+ import json5
8
+ import torch
9
+
10
+ from tqdm import tqdm
11
+
12
+ class ModelConfig:
13
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
14
+ """
15
+ Initialize a model configuration.
16
+
17
+ name: Name of the model
18
+ url: URL to download the model from
19
+ path: Path to the model file. If not set, the model will be downloaded from the URL.
20
+ type: Type of model. Can be whisper or huggingface.
21
+ """
22
+ self.name = name
23
+ self.url = url
24
+ self.path = path
25
+ self.type = type
26
+
27
+ class VadInitialPromptMode(Enum):
28
+ PREPEND_ALL_SEGMENTS = 1
29
+ PREPREND_FIRST_SEGMENT = 2
30
+
31
+ @staticmethod
32
+ def from_string(s: str):
33
+ normalized = s.lower() if s is not None else None
34
+
35
+ if normalized == "prepend_all_segments":
36
+ return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
37
+ elif normalized == "prepend_first_segment":
38
+ return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
39
+ else:
40
+ raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
41
+
42
+ class ApplicationConfig:
43
+ def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
44
+ share: bool = False, server_name: str = None, server_port: int = 7860,
45
+ queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
46
+ whisper_implementation: str = "whisper",
47
+ default_model_name: str = "medium", default_vad: str = "silero-vad",
48
+ vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
49
+ auto_parallel: bool = False, output_dir: str = None,
50
+ model_dir: str = None, device: str = None,
51
+ verbose: bool = True, task: str = "transcribe", language: str = None,
52
+ vad_initial_prompt_mode: str = "prepend_first_segment ",
53
+ vad_merge_window: float = 5, vad_max_merge_size: float = 30,
54
+ vad_padding: float = 1, vad_prompt_window: float = 3,
55
+ temperature: float = 0, best_of: int = 5, beam_size: int = 5,
56
+ patience: float = None, length_penalty: float = None,
57
+ suppress_tokens: str = "-1", initial_prompt: str = None,
58
+ condition_on_previous_text: bool = True, fp16: bool = True,
59
+ compute_type: str = "float16",
60
+ temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
61
+ logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
62
+
63
+ self.models = models
64
+
65
+ # WebUI settings
66
+ self.input_audio_max_duration = input_audio_max_duration
67
+ self.share = share
68
+ self.server_name = server_name
69
+ self.server_port = server_port
70
+ self.queue_concurrency_count = queue_concurrency_count
71
+ self.delete_uploaded_files = delete_uploaded_files
72
+
73
+ self.whisper_implementation = whisper_implementation
74
+ self.default_model_name = default_model_name
75
+ self.default_vad = default_vad
76
+ self.vad_parallel_devices = vad_parallel_devices
77
+ self.vad_cpu_cores = vad_cpu_cores
78
+ self.vad_process_timeout = vad_process_timeout
79
+ self.auto_parallel = auto_parallel
80
+ self.output_dir = output_dir
81
+
82
+ self.model_dir = model_dir
83
+ self.device = device
84
+ self.verbose = verbose
85
+ self.task = task
86
+ self.language = language
87
+ self.vad_initial_prompt_mode = vad_initial_prompt_mode
88
+ self.vad_merge_window = vad_merge_window
89
+ self.vad_max_merge_size = vad_max_merge_size
90
+ self.vad_padding = vad_padding
91
+ self.vad_prompt_window = vad_prompt_window
92
+ self.temperature = temperature
93
+ self.best_of = best_of
94
+ self.beam_size = beam_size
95
+ self.patience = patience
96
+ self.length_penalty = length_penalty
97
+ self.suppress_tokens = suppress_tokens
98
+ self.initial_prompt = initial_prompt
99
+ self.condition_on_previous_text = condition_on_previous_text
100
+ self.fp16 = fp16
101
+ self.compute_type = compute_type
102
+ self.temperature_increment_on_fallback = temperature_increment_on_fallback
103
+ self.compression_ratio_threshold = compression_ratio_threshold
104
+ self.logprob_threshold = logprob_threshold
105
+ self.no_speech_threshold = no_speech_threshold
106
+
107
+ def get_model_names(self):
108
+ return [ x.name for x in self.models ]
109
+
110
+ def update(self, **new_values):
111
+ result = ApplicationConfig(**self.__dict__)
112
+
113
+ for key, value in new_values.items():
114
+ setattr(result, key, value)
115
+ return result
116
+
117
+ @staticmethod
118
+ def create_default(**kwargs):
119
+ app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
120
+
121
+ # Update with kwargs
122
+ if len(kwargs) > 0:
123
+ app_config = app_config.update(**kwargs)
124
+ return app_config
125
+
126
+ @staticmethod
127
+ def parse_file(config_path: str):
128
+ import json5
129
+
130
+ with open(config_path, "r") as f:
131
+ # Load using json5
132
+ data = json5.load(f)
133
+ data_models = data.pop("models", [])
134
+
135
+ models = [ ModelConfig(**x) for x in data_models ]
136
+
137
+ return ApplicationConfig(models, **data)
src/conversion/hf_converter.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
2
+
3
+ from copy import deepcopy
4
+ import torch
5
+
6
+ WHISPER_MAPPING = {
7
+ "layers": "blocks",
8
+ "fc1": "mlp.0",
9
+ "fc2": "mlp.2",
10
+ "final_layer_norm": "mlp_ln",
11
+ "layers": "blocks",
12
+ ".self_attn.q_proj": ".attn.query",
13
+ ".self_attn.k_proj": ".attn.key",
14
+ ".self_attn.v_proj": ".attn.value",
15
+ ".self_attn_layer_norm": ".attn_ln",
16
+ ".self_attn.out_proj": ".attn.out",
17
+ ".encoder_attn.q_proj": ".cross_attn.query",
18
+ ".encoder_attn.k_proj": ".cross_attn.key",
19
+ ".encoder_attn.v_proj": ".cross_attn.value",
20
+ ".encoder_attn_layer_norm": ".cross_attn_ln",
21
+ ".encoder_attn.out_proj": ".cross_attn.out",
22
+ "decoder.layer_norm.": "decoder.ln.",
23
+ "encoder.layer_norm.": "encoder.ln_post.",
24
+ "embed_tokens": "token_embedding",
25
+ "encoder.embed_positions.weight": "encoder.positional_embedding",
26
+ "decoder.embed_positions.weight": "decoder.positional_embedding",
27
+ "layer_norm": "ln_post",
28
+ }
29
+
30
+
31
+ def rename_keys(s_dict):
32
+ keys = list(s_dict.keys())
33
+ for key in keys:
34
+ new_key = key
35
+ for k, v in WHISPER_MAPPING.items():
36
+ if k in key:
37
+ new_key = new_key.replace(k, v)
38
+
39
+ print(f"{key} -> {new_key}")
40
+
41
+ s_dict[new_key] = s_dict.pop(key)
42
+ return s_dict
43
+
44
+
45
+ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
46
+ from transformers import WhisperForConditionalGeneration
47
+ transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
48
+ config = transformer_model.config
49
+
50
+ # first build dims
51
+ dims = {
52
+ 'n_mels': config.num_mel_bins,
53
+ 'n_vocab': config.vocab_size,
54
+ 'n_audio_ctx': config.max_source_positions,
55
+ 'n_audio_state': config.d_model,
56
+ 'n_audio_head': config.encoder_attention_heads,
57
+ 'n_audio_layer': config.encoder_layers,
58
+ 'n_text_ctx': config.max_target_positions,
59
+ 'n_text_state': config.d_model,
60
+ 'n_text_head': config.decoder_attention_heads,
61
+ 'n_text_layer': config.decoder_layers
62
+ }
63
+
64
+ state_dict = deepcopy(transformer_model.model.state_dict())
65
+ state_dict = rename_keys(state_dict)
66
+
67
+ torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
src/download.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # In case of an OS error, try again with a different output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ pass
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'paths': {
34
+ 'home': destinationDirectory
35
+ }
36
+ }
37
+ if (playlistItems):
38
+ ydl_opts['playlist_items'] = playlistItems
39
+
40
+ # Add output template if specified
41
+ if outputTemplate:
42
+ ydl_opts['outtmpl'] = outputTemplate
43
+
44
+ filename_collector = FilenameCollectorPP()
45
+
46
+ with YoutubeDL(ydl_opts) as ydl:
47
+ if maxDuration and maxDuration > 0:
48
+ info = ydl.extract_info(url, download=False)
49
+ entries = "entries" in info and info["entries"] or [info]
50
+
51
+ total_duration = 0
52
+
53
+ # Compute total duration
54
+ for entry in entries:
55
+ total_duration += float(entry["duration"])
56
+
57
+ if total_duration >= maxDuration:
58
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
59
+
60
+ ydl.add_post_processor(filename_collector)
61
+ ydl.download([url])
62
+
63
+ if len(filename_collector.filenames) <= 0:
64
+ raise Exception("Cannot download " + url)
65
+
66
+ result = []
67
+
68
+ for filename in filename_collector.filenames:
69
+ result.append(filename)
70
+ print("Downloaded " + filename)
71
+
72
+ return result
73
+
74
+ class ExceededMaximumDuration(Exception):
75
+ def __init__(self, videoDuration, maxDuration, message):
76
+ self.videoDuration = videoDuration
77
+ self.maxDuration = maxDuration
78
+ super().__init__(message)
src/hooks/progressListener.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ class ProgressListener:
4
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
5
+ self.total = total
6
+
7
+ def on_finished(self):
8
+ pass
src/hooks/subTaskProgressListener.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.hooks.progressListener import ProgressListener
2
+
3
+ from typing import Union
4
+
5
+ class SubTaskProgressListener(ProgressListener):
6
+ """
7
+ A sub task listener that reports the progress of a sub task to a base task listener
8
+ Parameters
9
+ ----------
10
+ base_task_listener : ProgressListener
11
+ The base progress listener to accumulate overall progress in.
12
+ base_task_total : float
13
+ The maximum total progress that will be reported to the base progress listener.
14
+ sub_task_start : float
15
+ The starting progress of a sub task, in respect to the base progress listener.
16
+ sub_task_total : float
17
+ The total amount of progress a sub task will report to the base progress listener.
18
+ """
19
+ def __init__(
20
+ self,
21
+ base_task_listener: ProgressListener,
22
+ base_task_total: float,
23
+ sub_task_start: float,
24
+ sub_task_total: float,
25
+ ):
26
+ self.base_task_listener = base_task_listener
27
+ self.base_task_total = base_task_total
28
+ self.sub_task_start = sub_task_start
29
+ self.sub_task_total = sub_task_total
30
+
31
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
32
+ sub_task_progress_frac = current / total
33
+ sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
34
+ self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
35
+
36
+ def on_finished(self):
37
+ self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
src/hooks/whisperProgressHook.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import threading
3
+ from typing import List, Union
4
+ import tqdm
5
+
6
+ from src.hooks.progressListener import ProgressListener
7
+
8
+ class ProgressListenerHandle:
9
+ def __init__(self, listener: ProgressListener):
10
+ self.listener = listener
11
+
12
+ def __enter__(self):
13
+ register_thread_local_progress_listener(self.listener)
14
+
15
+ def __exit__(self, exc_type, exc_val, exc_tb):
16
+ unregister_thread_local_progress_listener(self.listener)
17
+
18
+ if exc_type is None:
19
+ self.listener.on_finished()
20
+
21
+ class _CustomProgressBar(tqdm.tqdm):
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._current = self.n # Set the initial value
25
+
26
+ def update(self, n):
27
+ super().update(n)
28
+ # Because the progress bar might be disabled, we need to manually update the progress
29
+ self._current += n
30
+
31
+ # Inform listeners
32
+ listeners = _get_thread_local_listeners()
33
+
34
+ for listener in listeners:
35
+ listener.on_progress(self._current, self.total)
36
+
37
+ _thread_local = threading.local()
38
+
39
+ def _get_thread_local_listeners():
40
+ if not hasattr(_thread_local, 'listeners'):
41
+ _thread_local.listeners = []
42
+ return _thread_local.listeners
43
+
44
+ _hooked = False
45
+
46
+ def init_progress_hook():
47
+ global _hooked
48
+
49
+ if _hooked:
50
+ return
51
+
52
+ # Inject into tqdm.tqdm of Whisper, so we can see progress
53
+ import whisper.transcribe
54
+ transcribe_module = sys.modules['whisper.transcribe']
55
+ transcribe_module.tqdm.tqdm = _CustomProgressBar
56
+ _hooked = True
57
+
58
+ def register_thread_local_progress_listener(progress_listener: ProgressListener):
59
+ # This is a workaround for the fact that the progress bar is not exposed in the API
60
+ init_progress_hook()
61
+
62
+ listeners = _get_thread_local_listeners()
63
+ listeners.append(progress_listener)
64
+
65
+ def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
66
+ listeners = _get_thread_local_listeners()
67
+
68
+ if progress_listener in listeners:
69
+ listeners.remove(progress_listener)
70
+
71
+ def create_progress_listener_handle(progress_listener: ProgressListener):
72
+ return ProgressListenerHandle(progress_listener)
73
+
74
+ # Example usage
75
+ if __name__ == '__main__':
76
+ class PrintingProgressListener:
77
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
78
+ print(f"Progress: {current}/{total}")
79
+
80
+ def on_finished(self):
81
+ print("Finished")
82
+
83
+ import whisper
84
+ model = whisper.load_model("medium")
85
+
86
+ with create_progress_listener_handle(PrintingProgressListener()) as listener:
87
+ # Set verbose to None to disable the progress bar, as we are using our own
88
+ result = model.transcribe("J:\\Dev\\OpenAI\\whisper\\tests\\Noriko\\out.mka", language="Japanese", fp16=False, verbose=None)
89
+ print(result)
90
+
91
+ print("Done")
src/languages.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Language():
2
+ def __init__(self, code, name):
3
+ self.code = code
4
+ self.name = name
5
+
6
+ def __str__(self):
7
+ return "Language(code={}, name={})".format(self.code, self.name)
8
+
9
+ LANGUAGES = [
10
+ Language('en', 'English'),
11
+ Language('zh', 'Chinese'),
12
+ Language('de', 'German'),
13
+ Language('es', 'Spanish'),
14
+ Language('ru', 'Russian'),
15
+ Language('ko', 'Korean'),
16
+ Language('fr', 'French'),
17
+ Language('ja', 'Japanese'),
18
+ Language('pt', 'Portuguese'),
19
+ Language('tr', 'Turkish'),
20
+ Language('pl', 'Polish'),
21
+ Language('ca', 'Catalan'),
22
+ Language('nl', 'Dutch'),
23
+ Language('ar', 'Arabic'),
24
+ Language('sv', 'Swedish'),
25
+ Language('it', 'Italian'),
26
+ Language('id', 'Indonesian'),
27
+ Language('hi', 'Hindi'),
28
+ Language('fi', 'Finnish'),
29
+ Language('vi', 'Vietnamese'),
30
+ Language('he', 'Hebrew'),
31
+ Language('uk', 'Ukrainian'),
32
+ Language('el', 'Greek'),
33
+ Language('ms', 'Malay'),
34
+ Language('cs', 'Czech'),
35
+ Language('ro', 'Romanian'),
36
+ Language('da', 'Danish'),
37
+ Language('hu', 'Hungarian'),
38
+ Language('ta', 'Tamil'),
39
+ Language('no', 'Norwegian'),
40
+ Language('th', 'Thai'),
41
+ Language('ur', 'Urdu'),
42
+ Language('hr', 'Croatian'),
43
+ Language('bg', 'Bulgarian'),
44
+ Language('lt', 'Lithuanian'),
45
+ Language('la', 'Latin'),
46
+ Language('mi', 'Maori'),
47
+ Language('ml', 'Malayalam'),
48
+ Language('cy', 'Welsh'),
49
+ Language('sk', 'Slovak'),
50
+ Language('te', 'Telugu'),
51
+ Language('fa', 'Persian'),
52
+ Language('lv', 'Latvian'),
53
+ Language('bn', 'Bengali'),
54
+ Language('sr', 'Serbian'),
55
+ Language('az', 'Azerbaijani'),
56
+ Language('sl', 'Slovenian'),
57
+ Language('kn', 'Kannada'),
58
+ Language('et', 'Estonian'),
59
+ Language('mk', 'Macedonian'),
60
+ Language('br', 'Breton'),
61
+ Language('eu', 'Basque'),
62
+ Language('is', 'Icelandic'),
63
+ Language('hy', 'Armenian'),
64
+ Language('ne', 'Nepali'),
65
+ Language('mn', 'Mongolian'),
66
+ Language('bs', 'Bosnian'),
67
+ Language('kk', 'Kazakh'),
68
+ Language('sq', 'Albanian'),
69
+ Language('sw', 'Swahili'),
70
+ Language('gl', 'Galician'),
71
+ Language('mr', 'Marathi'),
72
+ Language('pa', 'Punjabi'),
73
+ Language('si', 'Sinhala'),
74
+ Language('km', 'Khmer'),
75
+ Language('sn', 'Shona'),
76
+ Language('yo', 'Yoruba'),
77
+ Language('so', 'Somali'),
78
+ Language('af', 'Afrikaans'),
79
+ Language('oc', 'Occitan'),
80
+ Language('ka', 'Georgian'),
81
+ Language('be', 'Belarusian'),
82
+ Language('tg', 'Tajik'),
83
+ Language('sd', 'Sindhi'),
84
+ Language('gu', 'Gujarati'),
85
+ Language('am', 'Amharic'),
86
+ Language('yi', 'Yiddish'),
87
+ Language('lo', 'Lao'),
88
+ Language('uz', 'Uzbek'),
89
+ Language('fo', 'Faroese'),
90
+ Language('ht', 'Haitian creole'),
91
+ Language('ps', 'Pashto'),
92
+ Language('tk', 'Turkmen'),
93
+ Language('nn', 'Nynorsk'),
94
+ Language('mt', 'Maltese'),
95
+ Language('sa', 'Sanskrit'),
96
+ Language('lb', 'Luxembourgish'),
97
+ Language('my', 'Myanmar'),
98
+ Language('bo', 'Tibetan'),
99
+ Language('tl', 'Tagalog'),
100
+ Language('mg', 'Malagasy'),
101
+ Language('as', 'Assamese'),
102
+ Language('tt', 'Tatar'),
103
+ Language('haw', 'Hawaiian'),
104
+ Language('ln', 'Lingala'),
105
+ Language('ha', 'Hausa'),
106
+ Language('ba', 'Bashkir'),
107
+ Language('jw', 'Javanese'),
108
+ Language('su', 'Sundanese')
109
+ ]
110
+
111
+ _TO_LANGUAGE_CODE = {
112
+ **{language.code: language for language in LANGUAGES},
113
+ "burmese": "my",
114
+ "valencian": "ca",
115
+ "flemish": "nl",
116
+ "haitian": "ht",
117
+ "letzeburgesch": "lb",
118
+ "pushto": "ps",
119
+ "panjabi": "pa",
120
+ "moldavian": "ro",
121
+ "moldovan": "ro",
122
+ "sinhalese": "si",
123
+ "castilian": "es",
124
+ }
125
+
126
+ _FROM_LANGUAGE_NAME = {
127
+ **{language.name.lower(): language for language in LANGUAGES}
128
+ }
129
+
130
+ def get_language_from_code(language_code, default=None) -> Language:
131
+ """Return the language name from the language code."""
132
+ return _TO_LANGUAGE_CODE.get(language_code, default)
133
+
134
+ def get_language_from_name(language, default=None) -> Language:
135
+ """Return the language code from the language name."""
136
+ return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
137
+
138
+ def get_language_names():
139
+ """Return a list of language names."""
140
+ return [language.name for language in LANGUAGES]
141
+
142
+ if __name__ == "__main__":
143
+ # Test lookup
144
+ print(get_language_from_code('en'))
145
+ print(get_language_from_name('English'))
146
+
147
+ print(get_language_names())
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModelCache:
2
+ def __init__(self):
3
+ self._cache = dict()
4
+
5
+ def get(self, model_key: str, model_factory):
6
+ result = self._cache.get(model_key)
7
+
8
+ if result is None:
9
+ result = model_factory()
10
+ self._cache[model_key] = result
11
+ return result
12
+
13
+ def clear(self):
14
+ self._cache.clear()
15
+
16
+ # A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
17
+ GLOBAL_MODEL_CACHE = ModelCache()
src/segments.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
src/source.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
+ MAX_FILE_PREFIX_LENGTH = 17
13
+
14
+ class AudioSource:
15
+ def __init__(self, source_path, source_name = None, audio_duration = None):
16
+ self.source_path = source_path
17
+ self.source_name = source_name
18
+ self._audio_duration = audio_duration
19
+
20
+ # Load source name if not provided
21
+ if (self.source_name is None):
22
+ file_path = pathlib.Path(self.source_path)
23
+ self.source_name = file_path.name
24
+
25
+ def get_audio_duration(self):
26
+ if self._audio_duration is None:
27
+ self._audio_duration = float(ffmpeg.probe(self.source_path)["format"]["duration"])
28
+
29
+ return self._audio_duration
30
+
31
+ def get_full_name(self):
32
+ return self.source_name
33
+
34
+ def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
35
+ file_path = pathlib.Path(self.source_name)
36
+ short_name = file_path.stem[:max_length] + file_path.suffix
37
+
38
+ return short_name
39
+
40
+ def __str__(self) -> str:
41
+ return self.source_path
42
+
43
+ class AudioSourceCollection:
44
+ def __init__(self, sources: List[AudioSource]):
45
+ self.sources = sources
46
+
47
+ def __iter__(self):
48
+ return iter(self.sources)
49
+
50
+ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
51
+ output: List[AudioSource] = []
52
+
53
+ if urlData:
54
+ # Download from YouTube. This could also be a playlist or a channel.
55
+ output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
56
+ else:
57
+ # Add input files
58
+ if (multipleFiles is not None):
59
+ output.extend([ AudioSource(x.name) for x in multipleFiles ])
60
+ if (microphoneData is not None):
61
+ output.append(AudioSource(microphoneData))
62
+
63
+ total_duration = 0
64
+
65
+ # Calculate total audio length. We do this even if input_audio_max_duration
66
+ # is disabled to ensure that all the audio files are valid.
67
+ for source in output:
68
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
69
+ total_duration += float(audioDuration)
70
+
71
+ # Save audio duration
72
+ source._audio_duration = float(audioDuration)
73
+
74
+ # Ensure the total duration of the audio is not too long
75
+ if input_audio_max_duration > 0:
76
+ if float(total_duration) > input_audio_max_duration:
77
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
78
+
79
+ # Return a list of audio sources
80
+ return output
src/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+ import tqdm
8
+
9
+ import urllib3
10
+
11
+
12
+ def exact_div(x, y):
13
+ assert x % y == 0
14
+ return x // y
15
+
16
+
17
+ def str2bool(string):
18
+ str2val = {"True": True, "False": False}
19
+ if string in str2val:
20
+ return str2val[string]
21
+ else:
22
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
23
+
24
+
25
+ def optional_int(string):
26
+ return None if string == "None" else int(string)
27
+
28
+
29
+ def optional_float(string):
30
+ return None if string == "None" else float(string)
31
+
32
+
33
+ def compression_ratio(text) -> float:
34
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
35
+
36
+
37
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
38
+ assert seconds >= 0, "non-negative timestamp expected"
39
+ milliseconds = round(seconds * 1000.0)
40
+
41
+ hours = milliseconds // 3_600_000
42
+ milliseconds -= hours * 3_600_000
43
+
44
+ minutes = milliseconds // 60_000
45
+ milliseconds -= minutes * 60_000
46
+
47
+ seconds = milliseconds // 1_000
48
+ milliseconds -= seconds * 1_000
49
+
50
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
51
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
52
+
53
+
54
+ def write_txt(transcript: Iterator[dict], file: TextIO):
55
+ for segment in transcript:
56
+ print(segment['text'].strip(), file=file, flush=True)
57
+
58
+
59
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
60
+ print("WEBVTT\n", file=file)
61
+ for segment in transcript:
62
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
63
+
64
+ print(
65
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
66
+ f"{text}\n",
67
+ file=file,
68
+ flush=True,
69
+ )
70
+
71
+
72
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
73
+ """
74
+ Write a transcript to a file in SRT format.
75
+ Example usage:
76
+ from pathlib import Path
77
+ from whisper.utils import write_srt
78
+ result = transcribe(model, audio_path, temperature=temperature, **args)
79
+ # save SRT
80
+ audio_basename = Path(audio_path).stem
81
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
82
+ write_srt(result["segments"], file=srt)
83
+ """
84
+ for i, segment in enumerate(transcript, start=1):
85
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
86
+
87
+ # write srt lines
88
+ print(
89
+ f"{i}\n"
90
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
91
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
92
+ f"{text}\n",
93
+ file=file,
94
+ flush=True,
95
+ )
96
+
97
+ def process_text(text: str, maxLineWidth=None):
98
+ if (maxLineWidth is None or maxLineWidth < 0):
99
+ return text
100
+
101
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
102
+ return '\n'.join(lines)
103
+
104
+ def slugify(value, allow_unicode=False):
105
+ """
106
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
107
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
108
+ dashes to single dashes. Remove characters that aren't alphanumerics,
109
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
110
+ trailing whitespace, dashes, and underscores.
111
+ """
112
+ value = str(value)
113
+ if allow_unicode:
114
+ value = unicodedata.normalize('NFKC', value)
115
+ else:
116
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
117
+ value = re.sub(r'[^\w\s-]', '', value.lower())
118
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
119
+
120
+ def download_file(url: str, destination: str):
121
+ with urllib3.request.urlopen(url) as source, open(destination, "wb") as output:
122
+ with tqdm(
123
+ total=int(source.info().get("Content-Length")),
124
+ ncols=80,
125
+ unit="iB",
126
+ unit_scale=True,
127
+ unit_divisor=1024,
128
+ ) as loop:
129
+ while True:
130
+ buffer = source.read(8192)
131
+ if not buffer:
132
+ break
133
+
134
+ output.write(buffer)
135
+ loop.update(len(buffer))
src/vad.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.hooks.progressListener import ProgressListener
9
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
10
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
11
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
12
+
13
+ from src.segments import merge_timestamps
14
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
15
+
16
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
17
+ try:
18
+ import tensorflow as tf
19
+ except ModuleNotFoundError:
20
+ # Error handling
21
+ pass
22
+
23
+ import torch
24
+
25
+ import ffmpeg
26
+ import numpy as np
27
+
28
+ from src.utils import format_timestamp
29
+ from enum import Enum
30
+
31
+ class NonSpeechStrategy(Enum):
32
+ """
33
+ Ignore non-speech frames segments.
34
+ """
35
+ SKIP = 1
36
+ """
37
+ Just treat non-speech segments as speech.
38
+ """
39
+ CREATE_SEGMENT = 2
40
+ """
41
+ Expand speech segments into subsequent non-speech segments.
42
+ """
43
+ EXPAND_SEGMENT = 3
44
+
45
+ # Defaults for Silero
46
+ SPEECH_TRESHOLD = 0.3
47
+
48
+ # Minimum size of segments to process
49
+ MIN_SEGMENT_DURATION = 1
50
+
51
+ # The maximum time for texts from old segments to be used in the next segment
52
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
53
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
54
+
55
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
56
+
57
+ class TranscriptionConfig(ABC):
58
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
59
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
60
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
61
+ self.non_speech_strategy = non_speech_strategy
62
+ self.segment_padding_left = segment_padding_left
63
+ self.segment_padding_right = segment_padding_right
64
+ self.max_silent_period = max_silent_period
65
+ self.max_merge_size = max_merge_size
66
+ self.max_prompt_window = max_prompt_window
67
+ self.initial_segment_index = initial_segment_index
68
+
69
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
70
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
71
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
72
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
73
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
74
+ self.periodic_duration = periodic_duration
75
+
76
+ class AbstractTranscription(ABC):
77
+ def __init__(self, sampling_rate: int = 16000):
78
+ self.sampling_rate = sampling_rate
79
+
80
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
81
+ return load_audio(str, self.sampling_rate, start_time, duration)
82
+
83
+ def is_transcribe_timestamps_fast(self):
84
+ """
85
+ Determine if get_transcribe_timestamps is fast enough to not need parallelization.
86
+ """
87
+ return False
88
+
89
+ @abstractmethod
90
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
91
+ """
92
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
93
+
94
+ Parameters
95
+ ----------
96
+ audio: str
97
+ The audio file.
98
+ config: TranscriptionConfig
99
+ The transcription configuration.
100
+
101
+ Returns
102
+ -------
103
+ A list of start and end timestamps, in fractional seconds.
104
+ """
105
+ return
106
+
107
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
108
+ """
109
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method,
110
+ after merging the given segments using the specified configuration.
111
+
112
+ Parameters
113
+ ----------
114
+ audio: str
115
+ The audio file.
116
+ config: TranscriptionConfig
117
+ The transcription configuration.
118
+
119
+ Returns
120
+ -------
121
+ A list of start and end timestamps, in fractional seconds.
122
+ """
123
+ merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
124
+ config.segment_padding_left, config.segment_padding_right)
125
+
126
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
127
+ # Expand segments to include the gaps between them
128
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
129
+ # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
130
+ merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
131
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
132
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
133
+ merged = self.expand_gaps(merged, total_duration=total_duration)
134
+ else:
135
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
136
+
137
+ print("Transcribing non-speech:")
138
+ pprint(merged)
139
+ return merged
140
+
141
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
142
+ progressListener: ProgressListener = None):
143
+ """
144
+ Transcribe the given audo file.
145
+
146
+ Parameters
147
+ ----------
148
+ audio: str
149
+ The audio file.
150
+ whisperCallable: WhisperCallback
151
+ A callback object to call to transcribe each segment.
152
+
153
+ Returns
154
+ -------
155
+ A list of start and end timestamps, in fractional seconds.
156
+ """
157
+
158
+ try:
159
+ max_audio_duration = self.get_audio_duration(audio, config)
160
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
161
+
162
+ # Get speech timestamps from full audio file
163
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
164
+
165
+ # A deque of transcribed segments that is passed to the next segment as a prompt
166
+ prompt_window = deque()
167
+
168
+ print("Processing timestamps:")
169
+ pprint(merged)
170
+
171
+ result = {
172
+ 'text': "",
173
+ 'segments': [],
174
+ 'language': ""
175
+ }
176
+ languageCounter = Counter()
177
+ detected_language = None
178
+
179
+ segment_index = config.initial_segment_index
180
+
181
+ # Calculate progress
182
+ progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
183
+ progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
184
+
185
+ # For each time segment, run whisper
186
+ for segment in merged:
187
+ segment_index += 1
188
+ segment_start = segment['start']
189
+ segment_end = segment['end']
190
+ segment_expand_amount = segment.get('expand_amount', 0)
191
+ segment_gap = segment.get('gap', False)
192
+
193
+ segment_duration = segment_end - segment_start
194
+
195
+ if segment_duration < MIN_SEGMENT_DURATION:
196
+ continue
197
+
198
+ # Audio to run on Whisper
199
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
200
+ # Previous segments to use as a prompt
201
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
202
+
203
+ # Detected language
204
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
205
+
206
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
207
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
208
+
209
+ perf_start_time = time.perf_counter()
210
+
211
+ scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
212
+ sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
213
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
214
+
215
+ perf_end_time = time.perf_counter()
216
+ print("Whisper took {} seconds".format(perf_end_time - perf_start_time))
217
+
218
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
219
+
220
+ # Propagate expand amount to the segments
221
+ if (segment_expand_amount > 0):
222
+ segment_without_expansion = segment_duration - segment_expand_amount
223
+
224
+ for adjusted_segment in adjusted_segments:
225
+ adjusted_segment_end = adjusted_segment['end']
226
+
227
+ # Add expand amount if the segment got expanded
228
+ if (adjusted_segment_end > segment_without_expansion):
229
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
230
+
231
+ # Append to output
232
+ result['text'] += segment_result['text']
233
+ result['segments'].extend(adjusted_segments)
234
+
235
+ # Increment detected language
236
+ if not segment_gap:
237
+ languageCounter[segment_result['language']] += 1
238
+
239
+ # Update prompt window
240
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
241
+
242
+ if detected_language is not None:
243
+ result['language'] = detected_language
244
+ finally:
245
+ # Notify progress listener that we are done
246
+ if progressListener is not None:
247
+ progressListener.on_finished()
248
+ return result
249
+
250
+ def get_audio_duration(self, audio: str, config: TranscriptionConfig):
251
+ return get_audio_duration(audio)
252
+
253
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
254
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
255
+ # Add segments to the current prompt window (unless it is a speech gap)
256
+ if not segment_gap:
257
+ for segment in adjusted_segments:
258
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
259
+ prompt_window.append(segment)
260
+
261
+ while (len(prompt_window) > 0):
262
+ first_end_time = prompt_window[0].get('end', 0)
263
+ # Time expanded in the segments should be discounted from the prompt window
264
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
265
+
266
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
267
+ prompt_window.popleft()
268
+ else:
269
+ break
270
+
271
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
272
+ result = []
273
+ last_end_time = 0
274
+
275
+ for segment in segments:
276
+ segment_start = float(segment['start'])
277
+ segment_end = float(segment['end'])
278
+
279
+ if (last_end_time != segment_start):
280
+ delta = segment_start - last_end_time
281
+
282
+ if (min_gap_length is None or delta >= min_gap_length):
283
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
284
+
285
+ last_end_time = segment_end
286
+ result.append(segment)
287
+
288
+ # Also include total duration if specified
289
+ if (total_duration is not None and last_end_time < total_duration):
290
+ delta = total_duration - segment_start
291
+
292
+ if (min_gap_length is None or delta >= min_gap_length):
293
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
294
+
295
+ return result
296
+
297
+ # Expand the end time of each segment to the start of the next segment
298
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
299
+ result = []
300
+
301
+ if len(segments) == 0:
302
+ return result
303
+
304
+ # Add gap at the beginning if needed
305
+ if (segments[0]['start'] > 0):
306
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
307
+
308
+ for i in range(len(segments) - 1):
309
+ current_segment = segments[i]
310
+ next_segment = segments[i + 1]
311
+
312
+ delta = next_segment['start'] - current_segment['end']
313
+
314
+ # Expand if the gap actually exists
315
+ if (delta >= 0):
316
+ current_segment = current_segment.copy()
317
+ current_segment['expand_amount'] = delta
318
+ current_segment['end'] = next_segment['start']
319
+
320
+ result.append(current_segment)
321
+
322
+ # Add last segment
323
+ last_segment = segments[-1]
324
+ result.append(last_segment)
325
+
326
+ # Also include total duration if specified
327
+ if (total_duration is not None):
328
+ last_segment = result[-1]
329
+
330
+ if (last_segment['end'] < total_duration):
331
+ last_segment = last_segment.copy()
332
+ last_segment['end'] = total_duration
333
+ result[-1] = last_segment
334
+
335
+ return result
336
+
337
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
338
+ result = []
339
+
340
+ if len(segments) == 0:
341
+ return result
342
+
343
+ # Add gap at the beginning if needed
344
+ if (segments[0]['start'] > 0):
345
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
346
+
347
+ for i in range(len(segments) - 1):
348
+ expanded = False
349
+ current_segment = segments[i]
350
+ next_segment = segments[i + 1]
351
+
352
+ delta = next_segment['start'] - current_segment['end']
353
+
354
+ if (max_expand_size is not None and delta <= max_expand_size):
355
+ # Just expand the current segment
356
+ current_segment = current_segment.copy()
357
+ current_segment['expand_amount'] = delta
358
+ current_segment['end'] = next_segment['start']
359
+ expanded = True
360
+
361
+ result.append(current_segment)
362
+
363
+ # Add a gap to the next segment if needed
364
+ if (delta >= 0 and not expanded):
365
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
366
+
367
+ # Add last segment
368
+ last_segment = segments[-1]
369
+ result.append(last_segment)
370
+
371
+ # Also include total duration if specified
372
+ if (total_duration is not None):
373
+ last_segment = result[-1]
374
+
375
+ delta = total_duration - last_segment['end']
376
+
377
+ if (delta > 0):
378
+ if (max_expand_size is not None and delta <= max_expand_size):
379
+ # Expand the last segment
380
+ last_segment = last_segment.copy()
381
+ last_segment['expand_amount'] = delta
382
+ last_segment['end'] = total_duration
383
+ result[-1] = last_segment
384
+ else:
385
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
386
+
387
+ return result
388
+
389
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
390
+ result = []
391
+
392
+ for segment in segments:
393
+ segment_start = float(segment['start'])
394
+ segment_end = float(segment['end'])
395
+
396
+ # Filter segments?
397
+ if (max_source_time is not None):
398
+ if (segment_start > max_source_time):
399
+ continue
400
+ segment_end = min(max_source_time, segment_end)
401
+
402
+ new_segment = segment.copy()
403
+
404
+ # Add to start and end
405
+ new_segment['start'] = segment_start + adjust_seconds
406
+ new_segment['end'] = segment_end + adjust_seconds
407
+ result.append(new_segment)
408
+ return result
409
+
410
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
411
+ result = []
412
+
413
+ for entry in timestamps:
414
+ start = entry['start']
415
+ end = entry['end']
416
+
417
+ result.append({
418
+ 'start': start * factor,
419
+ 'end': end * factor
420
+ })
421
+ return result
422
+
423
+
424
+ class VadSileroTranscription(AbstractTranscription):
425
+ def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
426
+ super().__init__(sampling_rate=sampling_rate)
427
+ self.model = None
428
+ self.cache = cache
429
+ self._initialize_model()
430
+
431
+ def _initialize_model(self):
432
+ if (self.cache is not None):
433
+ model_key = "VadSileroTranscription"
434
+ self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
435
+ print("Loaded Silerio model from cache.")
436
+ else:
437
+ self.model, self.get_speech_timestamps = self._create_model()
438
+ print("Created Silerio model")
439
+
440
+ def _create_model(self):
441
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
442
+
443
+ # Silero does not benefit from multi-threading
444
+ torch.set_num_threads(1) # JIT
445
+ (get_speech_timestamps, _, _, _, _) = utils
446
+
447
+ return model, get_speech_timestamps
448
+
449
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
450
+ result = []
451
+
452
+ print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
453
+ perf_start_time = time.perf_counter()
454
+
455
+ # Divide procesisng of audio into chunks
456
+ chunk_start = start_time
457
+
458
+ while (chunk_start < end_time):
459
+ chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
460
+
461
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
462
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
463
+
464
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
465
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
466
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
467
+
468
+ #pprint(adjusted)
469
+
470
+ result.extend(adjusted)
471
+ chunk_start += chunk_duration
472
+
473
+ perf_end_time = time.perf_counter()
474
+ print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
475
+
476
+ return result
477
+
478
+ def __getstate__(self):
479
+ # We only need the sampling rate
480
+ return { 'sampling_rate': self.sampling_rate }
481
+
482
+ def __setstate__(self, state):
483
+ self.sampling_rate = state['sampling_rate']
484
+ self.model = None
485
+ # Use the global cache
486
+ self.cache = GLOBAL_MODEL_CACHE
487
+ self._initialize_model()
488
+
489
+ # A very simple VAD that just marks every N seconds as speech
490
+ class VadPeriodicTranscription(AbstractTranscription):
491
+ def __init__(self, sampling_rate: int = 16000):
492
+ super().__init__(sampling_rate=sampling_rate)
493
+
494
+ def is_transcribe_timestamps_fast(self):
495
+ # This is a very fast VAD - no need to parallelize it
496
+ return True
497
+
498
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
499
+ result = []
500
+
501
+ # Generate a timestamp every N seconds
502
+ start_timestamp = start_time
503
+
504
+ while (start_timestamp < end_time):
505
+ end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
506
+ segment_duration = end_timestamp - start_timestamp
507
+
508
+ # Minimum duration is 1 second
509
+ if (segment_duration >= 1):
510
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
511
+
512
+ start_timestamp = end_timestamp
513
+
514
+ return result
515
+
516
+ def get_audio_duration(file: str):
517
+ return float(ffmpeg.probe(file)["format"]["duration"])
518
+
519
+ def load_audio(file: str, sample_rate: int = 16000,
520
+ start_time: str = None, duration: str = None):
521
+ """
522
+ Open an audio file and read as mono waveform, resampling as necessary
523
+
524
+ Parameters
525
+ ----------
526
+ file: str
527
+ The audio file to open
528
+
529
+ sr: int
530
+ The sample rate to resample the audio if necessary
531
+
532
+ start_time: str
533
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
534
+
535
+ duration: str
536
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
537
+
538
+ Returns
539
+ -------
540
+ A NumPy array containing the audio waveform, in float32 dtype.
541
+ """
542
+ try:
543
+ inputArgs = {'threads': 0}
544
+
545
+ if (start_time is not None):
546
+ inputArgs['ss'] = start_time
547
+ if (duration is not None):
548
+ inputArgs['t'] = duration
549
+
550
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
551
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
552
+ out, _ = (
553
+ ffmpeg.input(file, **inputArgs)
554
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
555
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
556
+ )
557
+ except ffmpeg.Error as e:
558
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
559
+
560
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
src/vadParallel.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from queue import Empty
3
+ import threading
4
+ import time
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
7
+
8
+ from multiprocessing import Pool, Queue
9
+
10
+ from typing import Any, Dict, List, Union
11
+ import os
12
+
13
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
14
+
15
+ class _ProgressListenerToQueue(ProgressListener):
16
+ def __init__(self, progress_queue: Queue):
17
+ self.progress_queue = progress_queue
18
+ self.progress_total = 0
19
+ self.prev_progress = 0
20
+
21
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
22
+ delta = current - self.prev_progress
23
+ self.prev_progress = current
24
+ self.progress_total = total
25
+ self.progress_queue.put(delta)
26
+
27
+ def on_finished(self):
28
+ if self.progress_total > self.prev_progress:
29
+ delta = self.progress_total - self.prev_progress
30
+ self.progress_queue.put(delta)
31
+ self.prev_progress = self.progress_total
32
+
33
+ class ParallelContext:
34
+ def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
35
+ self.num_processes = num_processes
36
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
37
+ self.lock = threading.Lock()
38
+
39
+ self.ref_count = 0
40
+ self.pool = None
41
+ self.cleanup_timer = None
42
+
43
+ def get_pool(self):
44
+ # Initialize pool lazily
45
+ if (self.pool is None):
46
+ context = multiprocessing.get_context('spawn')
47
+ self.pool = context.Pool(self.num_processes)
48
+
49
+ self.ref_count = self.ref_count + 1
50
+
51
+ if (self.auto_cleanup_timeout_seconds is not None):
52
+ self._stop_auto_cleanup()
53
+
54
+ return self.pool
55
+
56
+ def return_pool(self, pool):
57
+ if (self.pool == pool and self.ref_count > 0):
58
+ self.ref_count = self.ref_count - 1
59
+
60
+ if (self.ref_count == 0):
61
+ if (self.auto_cleanup_timeout_seconds is not None):
62
+ self._start_auto_cleanup()
63
+
64
+ def _start_auto_cleanup(self):
65
+ if (self.cleanup_timer is not None):
66
+ self.cleanup_timer.cancel()
67
+ self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
68
+ self.cleanup_timer.start()
69
+
70
+ print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
71
+
72
+ def _stop_auto_cleanup(self):
73
+ if (self.cleanup_timer is not None):
74
+ self.cleanup_timer.cancel()
75
+ self.cleanup_timer = None
76
+
77
+ print("Stopped auto cleanup of pool")
78
+
79
+ def _execute_cleanup(self):
80
+ print("Executing cleanup of pool")
81
+
82
+ if (self.ref_count == 0):
83
+ self.close()
84
+
85
+ def close(self):
86
+ self._stop_auto_cleanup()
87
+
88
+ if (self.pool is not None):
89
+ print("Closing pool of " + str(self.num_processes) + " processes")
90
+ self.pool.close()
91
+ self.pool.join()
92
+ self.pool = None
93
+
94
+ class ParallelTranscriptionConfig(TranscriptionConfig):
95
+ def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
96
+ super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
97
+ self.device_id = device_id
98
+ self.override_timestamps = override_timestamps
99
+
100
+ class ParallelTranscription(AbstractTranscription):
101
+ # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
102
+ # into smaller segments than 2 minute (min 6 seconds per CPU core)
103
+ MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
104
+
105
+ def __init__(self, sampling_rate: int = 16000):
106
+ super().__init__(sampling_rate=sampling_rate)
107
+
108
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
109
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
110
+ progress_listener: ProgressListener = None):
111
+ total_duration = get_audio_duration(audio)
112
+
113
+ # First, get the timestamps for the original audio
114
+ if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
115
+ merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
116
+ else:
117
+ timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
118
+ merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
119
+
120
+ # We must make sure the whisper model is downloaded
121
+ if (len(gpu_devices) > 1):
122
+ whisperCallable.model_container.ensure_downloaded()
123
+
124
+ # Split into a list for each device
125
+ # TODO: Split by time instead of by number of chunks
126
+ merged_split = list(self._split(merged, len(gpu_devices)))
127
+
128
+ # Parameters that will be passed to the transcribe function
129
+ parameters = []
130
+ segment_index = config.initial_segment_index
131
+
132
+ processing_manager = multiprocessing.Manager()
133
+ progress_queue = processing_manager.Queue()
134
+
135
+ for i in range(len(gpu_devices)):
136
+ # Note that device_segment_list can be empty. But we will still create a process for it,
137
+ # as otherwise we run the risk of assigning the same device to multiple processes.
138
+ device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
139
+ device_id = gpu_devices[i]
140
+
141
+ print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
142
+
143
+ # Create a new config with the given device ID
144
+ device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
145
+ segment_index += len(device_segment_list)
146
+
147
+ progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
148
+ parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
149
+
150
+ merged = {
151
+ 'text': '',
152
+ 'segments': [],
153
+ 'language': None
154
+ }
155
+
156
+ created_context = False
157
+
158
+ perf_start_gpu = time.perf_counter()
159
+
160
+ # Spawn a separate process for each device
161
+ try:
162
+ if (gpu_parallel_context is None):
163
+ gpu_parallel_context = ParallelContext(len(gpu_devices))
164
+ created_context = True
165
+
166
+ # Get a pool of processes
167
+ pool = gpu_parallel_context.get_pool()
168
+
169
+ # Run the transcription in parallel
170
+ results_async = pool.starmap_async(self.transcribe, parameters)
171
+ total_progress = 0
172
+
173
+ while not results_async.ready():
174
+ try:
175
+ delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
176
+ except Empty:
177
+ continue
178
+
179
+ total_progress += delta
180
+ if progress_listener is not None:
181
+ progress_listener.on_progress(total_progress, total_duration)
182
+
183
+ results = results_async.get()
184
+
185
+ # Call the finished callback
186
+ if progress_listener is not None:
187
+ progress_listener.on_finished()
188
+
189
+ for result in results:
190
+ # Merge the results
191
+ if (result['text'] is not None):
192
+ merged['text'] += result['text']
193
+ if (result['segments'] is not None):
194
+ merged['segments'].extend(result['segments'])
195
+ if (result['language'] is not None):
196
+ merged['language'] = result['language']
197
+
198
+ finally:
199
+ # Return the pool to the context
200
+ if (gpu_parallel_context is not None):
201
+ gpu_parallel_context.return_pool(pool)
202
+ # Always close the context if we created it
203
+ if (created_context):
204
+ gpu_parallel_context.close()
205
+
206
+ perf_end_gpu = time.perf_counter()
207
+ print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
208
+
209
+ return merged
210
+
211
+ def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
212
+ cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
213
+ parameters = []
214
+
215
+ chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
216
+ chunk_start = 0
217
+ cpu_device_id = 0
218
+
219
+ perf_start_time = time.perf_counter()
220
+
221
+ # Create chunks that will be processed on the CPU
222
+ while (chunk_start < total_duration):
223
+ chunk_end = min(chunk_start + chunk_size, total_duration)
224
+
225
+ if (chunk_end - chunk_start < 1):
226
+ # No need to process chunks that are less than 1 second
227
+ break
228
+
229
+ print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
230
+ str(chunk_end) + " on CPU device " + str(cpu_device_id))
231
+ parameters.append([audio, config, chunk_start, chunk_end]);
232
+
233
+ cpu_device_id += 1
234
+ chunk_start = chunk_end
235
+
236
+ created_context = False
237
+
238
+ # Spawn a separate process for each device
239
+ try:
240
+ if (cpu_parallel_context is None):
241
+ cpu_parallel_context = ParallelContext(cpu_device_count)
242
+ created_context = True
243
+
244
+ # Get a pool of processes
245
+ pool = cpu_parallel_context.get_pool()
246
+
247
+ # Run the transcription in parallel. Note that transcription must be picklable.
248
+ results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
249
+
250
+ timestamps = []
251
+
252
+ # Flatten the results
253
+ for result in results:
254
+ timestamps.extend(result)
255
+
256
+ merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
257
+
258
+ perf_end_time = time.perf_counter()
259
+ print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
260
+ return merged
261
+
262
+ finally:
263
+ # Return the pool to the context
264
+ if (cpu_parallel_context is not None):
265
+ cpu_parallel_context.return_pool(pool)
266
+ # Always close the context if we created it
267
+ if (created_context):
268
+ cpu_parallel_context.close()
269
+
270
+ def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
271
+ return []
272
+
273
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
274
+ # Override timestamps that will be processed
275
+ if (config.override_timestamps is not None):
276
+ print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
277
+ return config.override_timestamps
278
+ return super().get_merged_timestamps(timestamps, config, total_duration)
279
+
280
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
281
+ progressListener: ProgressListener = None):
282
+ # Override device ID the first time
283
+ if (os.environ.get("INITIALIZED", None) is None):
284
+ os.environ["INITIALIZED"] = "1"
285
+
286
+ # Note that this may be None if the user didn't specify a device. In that case, Whisper will
287
+ # just use the default GPU device.
288
+ if (config.device_id is not None):
289
+ print("Using device " + config.device_id)
290
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
291
+
292
+ return super().transcribe(audio, whisperCallable, config, progressListener)
293
+
294
+ def _split(self, a, n):
295
+ """Split a list into n approximately equal parts."""
296
+ k, m = divmod(len(a), n)
297
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
298
+
src/whisper/abstractWhisperContainer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List
3
+ from src.config import ModelConfig, VadInitialPromptMode
4
+
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
7
+
8
+ class AbstractWhisperCallback:
9
+ @abc.abstractmethod
10
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
11
+ """
12
+ Peform the transcription of the given audio file or data.
13
+
14
+ Parameters
15
+ ----------
16
+ audio: Union[str, np.ndarray, torch.Tensor]
17
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
18
+ segment_index: int
19
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
20
+ task: str
21
+ The task - either translate or transcribe.
22
+ progress_listener: ProgressListener
23
+ A callback to receive progress updates.
24
+ """
25
+ raise NotImplementedError()
26
+
27
+ def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
28
+ prompt: str, segment_index: int):
29
+ if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
30
+ return self._concat_prompt(initial_prompt, prompt)
31
+ elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
32
+ return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
33
+ else:
34
+ raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
35
+
36
+ def _concat_prompt(self, prompt1, prompt2):
37
+ if (prompt1 is None):
38
+ return prompt2
39
+ elif (prompt2 is None):
40
+ return prompt1
41
+ else:
42
+ return prompt1 + " " + prompt2
43
+
44
+ class AbstractWhisperContainer:
45
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
46
+ download_root: str = None,
47
+ cache: ModelCache = None, models: List[ModelConfig] = []):
48
+ self.model_name = model_name
49
+ self.device = device
50
+ self.compute_type = compute_type
51
+ self.download_root = download_root
52
+ self.cache = cache
53
+
54
+ # Will be created on demand
55
+ self.model = None
56
+
57
+ # List of known models
58
+ self.models = models
59
+
60
+ def get_model(self):
61
+ if self.model is None:
62
+
63
+ if (self.cache is None):
64
+ self.model = self._create_model()
65
+ else:
66
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
67
+ self.model = self.cache.get(model_key, self._create_model)
68
+ return self.model
69
+
70
+ @abc.abstractmethod
71
+ def _create_model(self):
72
+ raise NotImplementedError()
73
+
74
+ def ensure_downloaded(self):
75
+ pass
76
+
77
+ @abc.abstractmethod
78
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
79
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
80
+ **decodeOptions: dict) -> AbstractWhisperCallback:
81
+ """
82
+ Create a WhisperCallback object that can be used to transcript audio files.
83
+
84
+ Parameters
85
+ ----------
86
+ language: str
87
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
88
+ task: str
89
+ The task - either translate or transcribe.
90
+ initial_prompt: str
91
+ The initial prompt to use for the transcription.
92
+ initial_prompt_mode: VadInitialPromptMode
93
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
94
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
95
+ decodeOptions: dict
96
+ Additional options to pass to the decoder. Must be pickleable.
97
+
98
+ Returns
99
+ -------
100
+ A WhisperCallback object.
101
+ """
102
+ raise NotImplementedError()
103
+
104
+ # This is required for multiprocessing
105
+ def __getstate__(self):
106
+ return {
107
+ "model_name": self.model_name,
108
+ "device": self.device,
109
+ "download_root": self.download_root,
110
+ "models": self.models,
111
+ "compute_type": self.compute_type
112
+ }
113
+
114
+ def __setstate__(self, state):
115
+ self.model_name = state["model_name"]
116
+ self.device = state["device"]
117
+ self.download_root = state["download_root"]
118
+ self.models = state["models"]
119
+ self.compute_type = state["compute_type"]
120
+ self.model = None
121
+ # Depickled objects must use the global cache
122
+ self.cache = GLOBAL_MODEL_CACHE
src/whisper/fasterWhisperContainer.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union
3
+
4
+ from faster_whisper import WhisperModel, download_model
5
+ from src.config import ModelConfig, VadInitialPromptMode
6
+ from src.hooks.progressListener import ProgressListener
7
+ from src.languages import get_language_from_name
8
+ from src.modelCache import ModelCache
9
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
10
+ from src.utils import format_timestamp
11
+
12
+ class FasterWhisperContainer(AbstractWhisperContainer):
13
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
14
+ download_root: str = None,
15
+ cache: ModelCache = None, models: List[ModelConfig] = []):
16
+ super().__init__(model_name, device, compute_type, download_root, cache, models)
17
+
18
+ def ensure_downloaded(self):
19
+ """
20
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
21
+ passing the container to a subprocess.
22
+ """
23
+ model_config = self._get_model_config()
24
+
25
+ if os.path.isdir(model_config.url):
26
+ model_config.path = model_config.url
27
+ else:
28
+ model_config.path = download_model(model_config.url, output_dir=self.download_root)
29
+
30
+ def _get_model_config(self) -> ModelConfig:
31
+ """
32
+ Get the model configuration for the model.
33
+ """
34
+ for model in self.models:
35
+ if model.name == self.model_name:
36
+ return model
37
+ return None
38
+
39
+ def _create_model(self):
40
+ print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
41
+ model_config = self._get_model_config()
42
+
43
+ if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
44
+ raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
45
+
46
+ device = self.device
47
+
48
+ if (device is None):
49
+ device = "auto"
50
+
51
+ model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
52
+ return model
53
+
54
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
55
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
56
+ **decodeOptions: dict) -> AbstractWhisperCallback:
57
+ """
58
+ Create a WhisperCallback object that can be used to transcript audio files.
59
+
60
+ Parameters
61
+ ----------
62
+ language: str
63
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
64
+ task: str
65
+ The task - either translate or transcribe.
66
+ initial_prompt: str
67
+ The initial prompt to use for the transcription.
68
+ initial_prompt_mode: VadInitialPromptMode
69
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
70
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
71
+ decodeOptions: dict
72
+ Additional options to pass to the decoder. Must be pickleable.
73
+
74
+ Returns
75
+ -------
76
+ A WhisperCallback object.
77
+ """
78
+ return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
79
+
80
+ class FasterWhisperCallback(AbstractWhisperCallback):
81
+ def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
82
+ initial_prompt: str = None, initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
83
+ **decodeOptions: dict):
84
+ self.model_container = model_container
85
+ self.language = language
86
+ self.task = task
87
+ self.initial_prompt = initial_prompt
88
+ self.initial_prompt_mode = initial_prompt_mode
89
+ self.decodeOptions = decodeOptions
90
+
91
+ self._printed_warning = False
92
+
93
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
94
+ """
95
+ Peform the transcription of the given audio file or data.
96
+
97
+ Parameters
98
+ ----------
99
+ audio: Union[str, np.ndarray, torch.Tensor]
100
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
101
+ segment_index: int
102
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
103
+ task: str
104
+ The task - either translate or transcribe.
105
+ progress_listener: ProgressListener
106
+ A callback to receive progress updates.
107
+ """
108
+ model: WhisperModel = self.model_container.get_model()
109
+ language_code = self._lookup_language_code(self.language) if self.language else None
110
+
111
+ # Copy decode options and remove options that are not supported by faster-whisper
112
+ decodeOptions = self.decodeOptions.copy()
113
+ verbose = decodeOptions.pop("verbose", None)
114
+
115
+ logprob_threshold = decodeOptions.pop("logprob_threshold", None)
116
+
117
+ patience = decodeOptions.pop("patience", None)
118
+ length_penalty = decodeOptions.pop("length_penalty", None)
119
+ suppress_tokens = decodeOptions.pop("suppress_tokens", None)
120
+
121
+ if (decodeOptions.pop("fp16", None) is not None):
122
+ if not self._printed_warning:
123
+ print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
124
+ self._printed_warning = True
125
+
126
+ # Fix up decode options
127
+ if (logprob_threshold is not None):
128
+ decodeOptions["log_prob_threshold"] = logprob_threshold
129
+
130
+ decodeOptions["patience"] = float(patience) if patience is not None else 1.0
131
+ decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
132
+
133
+ # See if supress_tokens is a string - if so, convert it to a list of ints
134
+ decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
135
+
136
+ initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
137
+
138
+ segments_generator, info = model.transcribe(audio, \
139
+ language=language_code if language_code else detected_language, task=self.task, \
140
+ initial_prompt=initial_prompt, \
141
+ **decodeOptions
142
+ )
143
+
144
+ segments = []
145
+
146
+ for segment in segments_generator:
147
+ segments.append(segment)
148
+
149
+ if progress_listener is not None:
150
+ progress_listener.on_progress(segment.end, info.duration)
151
+ if verbose:
152
+ print("[{}->{}] {}".format(format_timestamp(segment.start, True), format_timestamp(segment.end, True),
153
+ segment.text))
154
+
155
+ text = " ".join([segment.text for segment in segments])
156
+
157
+ # Convert the segments to a format that is easier to serialize
158
+ whisper_segments = [{
159
+ "text": segment.text,
160
+ "start": segment.start,
161
+ "end": segment.end,
162
+
163
+ # Extra fields added by faster-whisper
164
+ "words": [{
165
+ "start": word.start,
166
+ "end": word.end,
167
+ "word": word.word,
168
+ "probability": word.probability
169
+ } for word in (segment.words if segment.words is not None else []) ]
170
+ } for segment in segments]
171
+
172
+ result = {
173
+ "segments": whisper_segments,
174
+ "text": text,
175
+ "language": info.language if info else None,
176
+
177
+ # Extra fields added by faster-whisper
178
+ "language_probability": info.language_probability if info else None,
179
+ "duration": info.duration if info else None
180
+ }
181
+
182
+ if progress_listener is not None:
183
+ progress_listener.on_finished()
184
+ return result
185
+
186
+ def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
187
+ if (suppress_tokens is None):
188
+ return None
189
+ if (isinstance(suppress_tokens, list)):
190
+ return suppress_tokens
191
+
192
+ return [int(token) for token in suppress_tokens.split(",")]
193
+
194
+ def _lookup_language_code(self, language: str):
195
+ language = get_language_from_name(language)
196
+
197
+ if language is None:
198
+ raise ValueError("Invalid language: " + language)
199
+
200
+ return language.code
src/whisper/whisperContainer.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import abc
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from urllib.parse import urlparse
7
+ import torch
8
+ import urllib3
9
+ from src.hooks.progressListener import ProgressListener
10
+
11
+ import whisper
12
+ from whisper import Whisper
13
+
14
+ from src.config import ModelConfig, VadInitialPromptMode
15
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
16
+
17
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
18
+ from src.utils import download_file
19
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
20
+
21
+ class WhisperContainer(AbstractWhisperContainer):
22
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
23
+ download_root: str = None,
24
+ cache: ModelCache = None, models: List[ModelConfig] = []):
25
+ if device is None:
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ super().__init__(model_name, device, compute_type, download_root, cache, models)
28
+
29
+ def ensure_downloaded(self):
30
+ """
31
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
32
+ passing the container to a subprocess.
33
+ """
34
+ # Warning: Using private API here
35
+ try:
36
+ root_dir = self.download_root
37
+ model_config = self._get_model_config()
38
+
39
+ if root_dir is None:
40
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
41
+
42
+ if self.model_name in whisper._MODELS:
43
+ whisper._download(whisper._MODELS[self.model_name], root_dir, False)
44
+ else:
45
+ # If the model is not in the official list, see if it needs to be downloaded
46
+ model_config.download_url(root_dir)
47
+ return True
48
+
49
+ except Exception as e:
50
+ # Given that the API is private, it could change at any time. We don't want to crash the program
51
+ print("Error pre-downloading model: " + str(e))
52
+ return False
53
+
54
+ def _get_model_config(self) -> ModelConfig:
55
+ """
56
+ Get the model configuration for the model.
57
+ """
58
+ for model in self.models:
59
+ if model.name == self.model_name:
60
+ return model
61
+ return None
62
+
63
+ def _create_model(self):
64
+ print("Loading whisper model " + self.model_name)
65
+ model_config = self._get_model_config()
66
+
67
+ # Note that the model will not be downloaded in the case of an official Whisper model
68
+ model_path = self._get_model_path(model_config, self.download_root)
69
+
70
+ return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
71
+
72
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
73
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
74
+ **decodeOptions: dict) -> AbstractWhisperCallback:
75
+ """
76
+ Create a WhisperCallback object that can be used to transcript audio files.
77
+
78
+ Parameters
79
+ ----------
80
+ language: str
81
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
82
+ task: str
83
+ The task - either translate or transcribe.
84
+ initial_prompt: str
85
+ The initial prompt to use for the transcription.
86
+ initial_prompt_mode: VadInitialPromptMode
87
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
88
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
89
+ decodeOptions: dict
90
+ Additional options to pass to the decoder. Must be pickleable.
91
+
92
+ Returns
93
+ -------
94
+ A WhisperCallback object.
95
+ """
96
+ return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
97
+
98
+ def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
99
+ from src.conversion.hf_converter import convert_hf_whisper
100
+ """
101
+ Download the model.
102
+
103
+ Parameters
104
+ ----------
105
+ model_config: ModelConfig
106
+ The model configuration.
107
+ """
108
+ # See if path is already set
109
+ if model_config.path is not None:
110
+ return model_config.path
111
+
112
+ if root_dir is None:
113
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
114
+
115
+ model_type = model_config.type.lower() if model_config.type is not None else "whisper"
116
+
117
+ if model_type in ["huggingface", "hf"]:
118
+ model_config.path = model_config.url
119
+ destination_target = os.path.join(root_dir, model_config.name + ".pt")
120
+
121
+ # Convert from HuggingFace format to Whisper format
122
+ if os.path.exists(destination_target):
123
+ print(f"File {destination_target} already exists, skipping conversion")
124
+ else:
125
+ print("Saving HuggingFace model in Whisper format to " + destination_target)
126
+ convert_hf_whisper(model_config.url, destination_target)
127
+
128
+ model_config.path = destination_target
129
+
130
+ elif model_type in ["whisper", "w"]:
131
+ model_config.path = model_config.url
132
+
133
+ # See if URL is just a file
134
+ if model_config.url in whisper._MODELS:
135
+ # No need to download anything - Whisper will handle it
136
+ model_config.path = model_config.url
137
+ elif model_config.url.startswith("file://"):
138
+ # Get file path
139
+ model_config.path = urlparse(model_config.url).path
140
+ # See if it is an URL
141
+ elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
142
+ # Extension (or file name)
143
+ extension = os.path.splitext(model_config.url)[-1]
144
+ download_target = os.path.join(root_dir, model_config.name + extension)
145
+
146
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
147
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
148
+
149
+ if not os.path.isfile(download_target):
150
+ download_file(model_config.url, download_target)
151
+ else:
152
+ print(f"File {download_target} already exists, skipping download")
153
+
154
+ model_config.path = download_target
155
+ # Must be a local file
156
+ else:
157
+ model_config.path = model_config.url
158
+
159
+ else:
160
+ raise ValueError(f"Unknown model type {model_type}")
161
+
162
+ return model_config.path
163
+
164
+ class WhisperCallback(AbstractWhisperCallback):
165
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None,
166
+ initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT, **decodeOptions: dict):
167
+ self.model_container = model_container
168
+ self.language = language
169
+ self.task = task
170
+ self.initial_prompt = initial_prompt
171
+ self.initial_prompt_mode = initial_prompt_mode
172
+ self.decodeOptions = decodeOptions
173
+
174
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
175
+ """
176
+ Peform the transcription of the given audio file or data.
177
+
178
+ Parameters
179
+ ----------
180
+ audio: Union[str, np.ndarray, torch.Tensor]
181
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
182
+ segment_index: int
183
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
184
+ task: str
185
+ The task - either translate or transcribe.
186
+ progress_listener: ProgressListener
187
+ A callback to receive progress updates.
188
+ """
189
+ model = self.model_container.get_model()
190
+
191
+ if progress_listener is not None:
192
+ with create_progress_listener_handle(progress_listener):
193
+ return self._transcribe(model, audio, segment_index, prompt, detected_language)
194
+ else:
195
+ return self._transcribe(model, audio, segment_index, prompt, detected_language)
196
+
197
+ def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
198
+ decodeOptions = self.decodeOptions.copy()
199
+
200
+ # Add fp16
201
+ if self.model_container.compute_type in ["fp16", "float16"]:
202
+ decodeOptions["fp16"] = True
203
+
204
+ initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
205
+
206
+ return model.transcribe(audio, \
207
+ language=self.language if self.language else detected_language, task=self.task, \
208
+ initial_prompt=initial_prompt, \
209
+ **decodeOptions
210
+ )
src/whisper/whisperFactory.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src import modelCache
3
+ from src.config import ModelConfig
4
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
5
+
6
+ def create_whisper_container(whisper_implementation: str,
7
+ model_name: str, device: str = None, compute_type: str = "float16",
8
+ download_root: str = None,
9
+ cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
10
+ print("Creating whisper container for " + whisper_implementation)
11
+
12
+ if (whisper_implementation == "whisper"):
13
+ from src.whisper.whisperContainer import WhisperContainer
14
+ return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
15
+ elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
+ from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
+ return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
18
+ else:
19
+ raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numppy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()