|
|
|
|
|
import json |
|
import string |
|
|
|
# Characters after which a long segment is preferably cut: ASCII punctuation
# (minus "-" and "'", which are excluded here) plus a few extra marks,
# including CJK ones.
_punctuation = "".join(c for c in string.punctuation if c not in ["-", "'"]) + "。,!?:”、…"


def _split_segment(segment, max_length, use_space):
    """Cut one over-long segment into several pieces using its word timings.

    Returns a list of new ``{"text", "start", "end"}`` dicts, or ``[segment]``
    unchanged when the segment carries no word-level information.
    """
    text = segment["text"]
    meta_words = segment.get("words") or []
    if not meta_words:
        # Without word timestamps there is nothing to cut on; keep the segment
        # as is instead of silently dropping its text.
        return [segment]

    if use_space:
        words = text.split()
    else:
        words = [w["text"] for w in meta_words]
    if len(words) != len(meta_words):
        # The text and the word list disagree: trust the word list, since it
        # is the one that carries the timestamps.
        new_words = [w["text"] for w in meta_words]
        print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
        words = new_words

    # Number of separator characters between two consecutive words.
    sep_len = 1 if use_space else 0

    pieces = []
    current_text = ""
    current_start = segment["start"]
    # Last position (in current_text) right after a punctuation mark, with the
    # end time of that word and the start time of the following one.
    best_idx = None
    best_end = None
    best_next_start = None

    for i, (word, meta) in enumerate(zip(words, meta_words)):
        text_before = current_text
        if current_text and use_space:
            current_text += " "
        current_text += word

        # A single word longer than max_length cannot be cut: only split when
        # there was already some text before the current word.
        if len(current_text) > max_length and text_before:
            if best_idx is not None:
                # Preferred cut: right after the last punctuation mark.
                pieces.append({"text": current_text[:best_idx], "start": current_start, "end": best_end})
                leftover = text_before[best_idx + sep_len:]
                # Skip the separator only if there is one (no space to skip
                # in languages written without spaces).
                current_text = current_text[best_idx + sep_len:]
                current_start = best_next_start
                if leftover and len(current_text) > max_length:
                    # What remains is still too long and holds more than the
                    # current word: cut again just before the current word.
                    pieces.append({"text": leftover, "start": current_start, "end": meta_words[i - 1]["end"]})
                    current_text = word
                    current_start = meta["start"]
            else:
                # No punctuation seen: cut just before the current word.
                pieces.append({"text": text_before, "start": current_start, "end": meta_words[i - 1]["end"]})
                current_text = word
                current_start = meta["start"]

            best_idx = None
            best_end = None
            best_next_start = None

        # Remember this position as a candidate cut if it follows punctuation.
        if current_text and current_text[-1] in _punctuation:
            best_idx = len(current_text)
            best_end = meta["end"]
            best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

    if current_text:
        pieces.append({"text": current_text, "start": current_start, "end": segment["end"]})

    return pieces


def split_long_segments(segments, max_length, use_space=True):
    """Split segments whose text is longer than ``max_length`` characters.

    Args:
        segments: iterable of dicts with keys ``"text"``, ``"start"`` and
            ``"end"``. Segments that need splitting are also read for
            ``"words"``, a list of dicts with ``"text"``, ``"start"`` and
            ``"end"``.
        max_length: maximum number of characters per output segment. A single
            word longer than this is kept whole.
        use_space: whether words are joined with a space. When True the words
            are taken from ``text.split()``; otherwise from the word list.

    Returns:
        A new list. Segments short enough (or without word-level information)
        are the original dict objects; split pieces are new dicts holding only
        ``"text"``, ``"start"`` and ``"end"``. Cuts are made after a
        punctuation mark when one is available, otherwise between words.
    """
    new_segments = []
    for segment in segments:
        if len(segment["text"]) <= max_length:
            new_segments.append(segment)
        else:
            new_segments.extend(_split_segment(segment, max_length, use_space))
    return new_segments
|
|
|
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Render a non-negative duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: duration in seconds; rounded to the nearest millisecond.
        always_include_hours: emit the ``HH:`` field even when it is zero.
        decimal_marker: string placed between the seconds and milliseconds.

    Returns:
        The formatted timestamp. The hours field is only present when it is
        non-zero or when ``always_include_hours`` is set.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Work on an integer number of milliseconds to avoid float drift.
    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        prefix = f"{hours:02d}:"
    else:
        prefix = ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
|
|
|
def write_vtt(result, file):
    """Write segments to ``file`` in WebVTT format.

    Args:
        result: iterable of dicts with keys ``"start"``, ``"end"`` and
            ``"text"``.
        file: writable text stream handed to ``print``.
    """
    print("WEBVTT\n", file=file)
    for cue in result:
        begin = format_timestamp(cue['start'])
        finish = format_timestamp(cue['end'])
        # "-->" separates the two timestamps, so it is replaced in the text.
        body = cue['text'].strip().replace('-->', '->')
        print(f"{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
def write_srt(result, file):
    """Write segments to ``file`` in SRT format, numbered from 1.

    Args:
        result: iterable of dicts with keys ``"start"``, ``"end"`` and
            ``"text"``.
        file: writable text stream handed to ``print``.
    """
    for index, cue in enumerate(result, start=1):
        # Timestamps here always carry hours and use a comma before the
        # milliseconds.
        begin = format_timestamp(cue['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(cue['end'], always_include_hours=True, decimal_marker=',')
        # "-->" separates the two timestamps, so it is replaced in the text.
        body = cue['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{body}\n", file=file, flush=True)
|
|
|
def cli():
    """Command-line entry point: convert word-level JSON transcripts to subtitles.

    Positional arguments are an input (a JSON file, or a folder scanned for
    ``*.words.json`` files) and an output (an ``.srt``/``.vtt`` file, or a
    folder). When the output is a folder, one file per requested format is
    written for each input. Segments longer than ``--max_length`` characters
    are split first; ``--max_length 0`` disables splitting.

    Raises:
        RuntimeError: if an output path has neither an ``.srt`` nor a ``.vtt``
            extension.
    """
    import os
    import argparse

    supported_formats = ["srt", "vtt"]
    words_suffix = ".words.json"

    parser = argparse.ArgumentParser(
        description='Convert .words.json transcription files (output of whisper_timestamped) to srt or vtt, being able to cut long segments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', type=str, help='Input json file, or input folder')
    parser.add_argument('output', type=str, help='Output srt or vtt file, or output folder')
    parser.add_argument('--max_length', default=200, help='Maximum length of a segment in characters', type=int)
    parser.add_argument('--format', type=str, default="all", help='Output format (if the output is a folder, i.e. not a file with an explicit extension)', choices=supported_formats + ["all"])
    args = parser.parse_args()

    def output_stem(filename):
        """Return the file name without its ``.words.json`` (or last) extension."""
        if filename.endswith(words_suffix):
            return filename[:-len(words_suffix)]
        return os.path.splitext(filename)[0]

    input_is_dir = os.path.isdir(args.input)
    # Require the dot so that a folder named e.g. "mysrt" is not mistaken for a file.
    output_is_file = any(args.output.endswith("." + e) for e in supported_formats)

    if input_is_dir or not output_is_file:
        # The output is a folder: derive output names from the input names.
        if input_is_dir:
            names = [f for f in os.listdir(args.input) if f.endswith(words_suffix)]
            input_files = [os.path.join(args.input, f) for f in names]
        else:
            names = [os.path.basename(args.input)]
            input_files = [args.input]
        extensions = [args.format] if args.format != "all" else supported_formats
        output_files = [[os.path.join(args.output, output_stem(f) + "." + e) for e in extensions] for f in names]
        os.makedirs(args.output, exist_ok=True)
    else:
        input_files = [args.input]
        output_files = [[args.output]]
        output_dir = os.path.dirname(args.output)
        # dirname is "" for a bare file name; os.makedirs("") would fail.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for fn, outputs in zip(input_files, output_files):
        with open(fn, "r", encoding="utf-8") as f:
            transcript = json.load(f)
        segments = transcript["segments"]
        if args.max_length:
            # Languages listed here are treated as written without spaces
            # between words; a missing language defaults to space-separated.
            language = transcript.get("language")
            use_space = language not in ["zh", "ja", "th", "lo", "my"]
            segments = split_long_segments(segments, args.max_length, use_space=use_space)
        for output in outputs:
            if output.endswith(".srt"):
                with open(output, "w", encoding="utf-8") as f:
                    write_srt(segments, file=f)
            elif output.endswith(".vtt"):
                with open(output, "w", encoding="utf-8") as f:
                    write_vtt(segments, file=f)
            else:
                raise RuntimeError(f"Unknown output format for {output}")


# Run the command-line interface when the file is executed as a script.
if __name__ == "__main__":
    cli()