BatuhanYilmaz
committed on
Commit
•
49acb19
1
Parent(s):
7e4e567
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textwrap
|
2 |
+
import unicodedata
|
3 |
+
import re
|
4 |
+
|
5 |
+
import zlib
|
6 |
+
from typing import Iterator, TextIO
|
7 |
+
|
8 |
+
|
9 |
+
def exact_div(x, y):
    """Return ``x // y`` after asserting that ``y`` divides ``x`` with no remainder.

    Raises:
        AssertionError: if ``x % y`` is not zero (only when assertions are enabled).
    """
    remainder = x % y
    assert remainder == 0
    quotient = x // y
    return quotient
|
12 |
+
|
13 |
+
|
14 |
+
def str2bool(string):
    """Map the literal strings ``"True"`` / ``"False"`` to the matching bool.

    Raises:
        ValueError: if ``string`` is neither of the two accepted literals.
    """
    str2val = {"True": True, "False": False}
    # Reject anything unexpected up front so the happy path is a plain lookup.
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
|
20 |
+
|
21 |
+
|
22 |
+
def optional_int(string):
    """Return ``None`` for the literal string ``"None"``, otherwise ``int(string)``."""
    if string == "None":
        return None
    return int(string)
|
24 |
+
|
25 |
+
|
26 |
+
def optional_float(string):
    """Return ``None`` for the literal string ``"None"``, otherwise ``float(string)``."""
    if string == "None":
        return None
    return float(string)
|
28 |
+
|
29 |
+
|
30 |
+
def compression_ratio(text) -> float:
    """Return the ratio of the UTF-8 size of ``text`` to its zlib-compressed size.

    Both sides of the ratio are measured in bytes. Dividing the character
    count by the compressed byte count would understate the ratio for
    non-ASCII text, where one character encodes to several bytes.
    """
    # Encode once and reuse the bytes for both the numerator and the compressor.
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
|
32 |
+
|
33 |
+
|
34 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Format a non-negative duration in seconds as ``[HH:]MM:SS<sep>mmm``.

    Args:
        seconds: duration to format; rounded to the nearest millisecond.
        always_include_hours: emit the ``HH:`` prefix even when hours is zero.
        fractionalSeperator: string placed between seconds and milliseconds.

    Returns:
        The formatted timestamp string.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    remaining_ms = round(seconds * 1000.0)

    # Peel off each unit in turn; divmod leaves the remainder for the next one.
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    if always_include_hours or hours > 0:
        hours_prefix = f"{hours:02d}:"
    else:
        hours_prefix = ""
    return f"{hours_prefix}{minutes:02d}:{whole_seconds:02d}{fractionalSeperator}{millis:03d}"
|
49 |
+
|
50 |
+
|
51 |
+
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write the stripped ``'text'`` of every segment to ``file``, one per line.

    The stream is flushed after each segment.
    """
    for entry in transcript:
        line = entry['text'].strip()
        print(line, file=file, flush=True)
|
54 |
+
|
55 |
+
|
56 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write ``transcript`` to ``file`` as WebVTT cues.

    Args:
        transcript: segments, each read via its ``'start'``, ``'end'`` and
            ``'text'`` keys.
        file: text stream that receives the output; flushed after each cue.
        maxLineWidth: forwarded to ``processText`` to control line wrapping.
    """
    # Header line; print() appends the newline that yields the blank line after it.
    print("WEBVTT\n", file=file)
    for cue in transcript:
        # '-->' separates the two timestamps below, so it is rewritten inside the cue text.
        body = processText(cue['text'], maxLineWidth).replace('-->', '->')
        begin = format_timestamp(cue['start'])
        finish = format_timestamp(cue['end'])
        print(f"{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
67 |
+
|
68 |
+
|
69 |
+
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """
    Write a transcript to a file in SRT format.

    Each segment becomes a block made of a 1-based index, a timing line
    (hours always included, comma before the milliseconds) and the stripped
    segment text, followed by a blank line.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, cue in enumerate(transcript, start=1):
        # '-->' separates the two timestamps below, so it is rewritten inside the cue text.
        body = processText(cue['text'].strip(), maxLineWidth).replace('-->', '->')
        begin = format_timestamp(cue['start'], always_include_hours=True, fractionalSeperator=',')
        finish = format_timestamp(cue['end'], always_include_hours=True, fractionalSeperator=',')

        # write srt lines
        print(f"{index}\n{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
93 |
+
|
94 |
+
def processText(text: str, maxLineWidth=None):
    """Wrap ``text`` so that no line is longer than ``maxLineWidth`` characters.

    Args:
        text: the text to wrap.
        maxLineWidth: maximum line width. ``None``, zero or a negative value
            disables wrapping and returns ``text`` unchanged.

    Returns:
        The wrapped lines joined with newline characters, or ``text`` itself
        when wrapping is disabled.
    """
    # textwrap.wrap rejects width <= 0 with ValueError, so zero is treated
    # like the other "no wrapping" values instead of crashing.
    if maxLineWidth is None or maxLineWidth <= 0:
        return text

    return '\n'.join(textwrap.wrap(text, width=maxLineWidth, tabsize=4))
|
100 |
+
|
101 |
+
def slugify(value, allow_unicode=False):
    """
    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Turn ``value`` into a slug: normalise it (to ASCII unless
    ``allow_unicode`` is true), lowercase it, drop every character that is
    not alphanumeric, an underscore, whitespace or a hyphen, collapse runs of
    whitespace and hyphens into a single hyphen, and strip leading and
    trailing hyphens and underscores.
    """
    text = str(value)
    if not allow_unicode:
        # Decompose first so that dropping non-ASCII bytes keeps the base letters.
        decomposed = unicodedata.normalize('NFKD', text)
        text = decomposed.encode('ascii', 'ignore').decode('ascii')
    else:
        text = unicodedata.normalize('NFKC', text)

    lowered = text.lower()
    cleaned = re.sub(r'[^\w\s-]', '', lowered)
    dashed = re.sub(r'[-\s]+', '-', cleaned)
    return dashed.strip('-_')
|