File size: 4,202 Bytes
262d511
 
 
 
 
 
0cc2cbd
 
262d511
 
 
 
0cc2cbd
 
 
 
 
 
 
 
d6d4252
0cc2cbd
 
 
 
 
 
 
262d511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cc2cbd
 
 
 
 
 
 
 
 
262d511
 
0cc2cbd
 
 
1221d26
0cc2cbd
262d511
 
 
 
 
 
 
 
 
 
 
 
 
 
048700d
262d511
 
 
 
 
 
 
 
 
 
0cc2cbd
 
 
 
 
262d511
 
 
 
 
 
 
 
 
 
e7cb8eb
262d511
79eccc5
e9c72a6
0cc2cbd
 
262d511
 
 
bc815d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os, sys, re
import shutil
import subprocess
import soundfile
from process_audio import segment_audio
from write_srt import write_to_file
from clean_text import clean_english, clean_german, clean_spanish
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import gradio as gr


def _load_wav2vec2(checkpoint):
    """Load the processor and the CTC model published under ``checkpoint``.

    Returns:
        A ``(processor, model)`` tuple, loaded in that order.
    """
    processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    ctc_model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
    return processor, ctc_model


# One checkpoint per supported language. The `<language>_tokenizer` and
# `<language>_asr_model` names are resolved dynamically through globals()
# in get_subs(), so they must keep exactly this spelling.
english_model = "facebook/wav2vec2-large-960h-lv60-self"
english_tokenizer, english_asr_model = _load_wav2vec2(english_model)

german_model = "jonatasgrosman/wav2vec2-large-xlsr-53-german"
german_tokenizer, german_asr_model = _load_wav2vec2(german_model)

spanish_model = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
spanish_tokenizer, spanish_asr_model = _load_wav2vec2(spanish_model)

# Download the TextBlob/NLTK corpora at start-up (presumably required by the
# clean_* helpers -- confirm against clean_text). Run the downloader with the
# interpreter that is executing this script so the corpora are installed for
# this environment rather than for whichever "python" is first on PATH.
# The return code is deliberately not checked: the download is best-effort.
command = [sys.executable, "-m", "textblob.download_corpora"]
subprocess.run(command)


# Running SRT entry number. transcribe_audio() increments it for every
# non-empty transcript and passes the new value to write_to_file().
line_count = 0

def sort_alphanumeric(data):
    """Return ``data`` sorted in natural (human) order.

    Runs of ASCII digits are compared by numeric value and everything else
    is compared case-insensitively, so ``"seg_2"`` sorts before ``"seg_10"``.

    Args:
        data: Iterable of strings.

    Returns:
        A new list with the items of ``data`` in natural order.
    """
    def natural_key(name):
        # re.split with a single capturing group alternates text and digit
        # chunks: even indices are text, odd indices are the [0-9]+ matches.
        # Deciding by position (rather than str.isdigit) keeps every key
        # shaped [str, int, str, ...], so non-ASCII digit characters such as
        # "²" can neither crash int() nor put an int where other keys have
        # a str.
        chunks = re.split(r'([0-9]+)', name)
        return [
            int(chunk) if index % 2 else chunk.lower()
            for index, chunk in enumerate(chunks)
        ]

    return sorted(data, key=natural_key)

def transcribe_audio(tokenizer, asr_model, audio_file, file_handle):
    """Transcribe one audio segment and append it to the SRT file.

    Runs Wav2Vec2 CTC inference on ``audio_file``, lower-cases the greedy
    decoding, cleans it with the cleaner matching the module-level ``lang``
    (set by ``get_subs`` before this is called) and hands the text to
    ``write_to_file`` together with the incremented global ``line_count``.
    Transcripts of at most one character are skipped and do not consume an
    entry number.

    Args:
        tokenizer: Processor used both to featurise the waveform and to
            decode the predicted ids.
        asr_model: CTC model whose ``.logits`` are arg-maxed per frame.
        audio_file: Path of the segment. The last ``_``-separated field of
            the file stem is split on ``-`` and passed on as ``limits``
            (presumably the segment's start/end times -- confirm against
            segment_audio / write_to_file).
        file_handle: Open handle of the SRT file being written.
    """
    global line_count

    # The file's own sample rate is not used: the processor is told 16 kHz,
    # which is the rate get_subs() asks ffmpeg to produce.
    speech, _ = soundfile.read(audio_file)
    input_values = tokenizer(
        speech, sampling_rate=16000, return_tensors="pt", padding='longest'
    ).input_values

    # Inference only: without no_grad() every forward pass would record an
    # autograd graph, wasting memory and time for no benefit.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    prediction = torch.argmax(logits, dim=-1)

    infered_text = tokenizer.batch_decode(prediction)[0].lower()
    if len(infered_text) <= 1:
        return

    cleaners = {
        'english': clean_english,
        'german': clean_german,
        'spanish': clean_spanish,
    }
    cleaner = cleaners.get(lang)
    if cleaner is not None:
        infered_text = cleaner(infered_text)

    print(infered_text)
    # Strip the directory and the extension (whatever its length) before
    # parsing the limits out of the file name.
    segment_stem = os.path.splitext(os.path.basename(audio_file))[0]
    limits = segment_stem.split("_")[-1].split("-")
    line_count += 1
    write_to_file(file_handle, infered_text, line_count, limits)

        
def get_subs(input_file, language):
    """Create an SRT subtitle file for a video and return its path.

    Steps:
      1. Extract a mono 16 kHz WAV track from ``input_file`` with ffmpeg
         into a scratch ``audio`` directory under the current working
         directory.
      2. Split that track with ``segment_audio`` (on VAD silent segments,
         per the original comment) and delete the full-length track.
      3. Transcribe every file left in the scratch directory, in natural
         order, with the models selected by ``language``.

    Args:
        input_file: Path of the uploaded video.
        language: One of the language names offered by the UI
            (case-insensitive); it selects the ``<language>_tokenizer`` and
            ``<language>_asr_model`` module globals.

    Returns:
        The SRT file name (the video's base name with ``.srt``), relative
        to the current working directory.

    Raises:
        ValueError: If ``language`` is empty or has no loaded model.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    global lang, line_count

    # transcribe_audio() reads `lang` to choose its text cleaner.
    lang = (language or "").lower()
    try:
        # Resolved once up front; the choice does not change per segment.
        tokenizer = globals()[lang + '_tokenizer']
        asr_model = globals()[lang + '_asr_model']
    except KeyError:
        raise ValueError(f"Unsupported language: {language!r}") from None

    # Every subtitle file starts numbering from 1, also on later requests.
    line_count = 0

    # Get directory for audio
    base_directory = os.getcwd()
    audio_directory = os.path.join(base_directory, "audio")
    if os.path.isdir(audio_directory):
        shutil.rmtree(audio_directory)
    os.mkdir(audio_directory)

    # splitext copes with extensions of any length (".webm", ".mpeg", ...).
    video_name = os.path.splitext(os.path.basename(input_file))[0]
    srt_file_name = video_name + ".srt"

    try:
        # Extract audio from video file
        audio_file = audio_directory + '/temp.wav'
        command = ["ffmpeg", "-i", input_file, "-ac", "1", "-ar", "16000",
                   "-vn", "-f", "wav", audio_file]
        # Fail here with ffmpeg's status instead of later on a missing file.
        subprocess.run(command, check=True)

        # Split audio file based on VAD silent segments
        segment_audio(audio_file)
        os.remove(audio_file)
        full_track_name = os.path.basename(audio_file)

        # "w+" truncates, so re-processing a video with the same name does
        # not append a second copy of the subtitles to an old file.
        with open(srt_file_name, "w+", encoding="utf-8") as file_handle:
            for file in sort_alphanumeric(os.listdir(audio_directory)):
                if file == full_track_name:
                    continue
                audio_segment_path = os.path.join(audio_directory, file)
                transcribe_audio(tokenizer, asr_model, audio_segment_path, file_handle)
    finally:
        # Remove the scratch directory even when a step above failed.
        shutil.rmtree(audio_directory, ignore_errors=True)

    return srt_file_name


# Input widgets, in the positional order get_subs(input_file, language)
# receives their values.
video_input = gr.inputs.Video(label="Upload Video File")
language_input = gr.inputs.Radio(
    label="Choose Language",
    choices=['English', 'German', 'Spanish'],
)

# The returned SRT path is offered to the user as a downloadable file.
transcript_output = gr.outputs.File(label="Auto-Transcript")

gradio_ui = gr.Interface(
    fn=get_subs,
    inputs=[video_input, language_input],
    outputs=transcript_output,
    title="Video to Subtitle",
    description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
    enable_queue=True,
)

gradio_ui.launch()