import gradio as gr
import os
from dotenv import load_dotenv
from pydub import AudioSegment

load_dotenv()
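# Example .env (values are placeholders, not real credentials):
#   HF_API=hf_xxxxxxxxxxxxxxxxxxxx
#   API_URL=https://<owner>-<space>.hf.space/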

from lang_list import TEXT_SOURCE_LANGUAGE_NAMES
from gradio_client import Client

HF_API = os.getenv("HF_API")
API_URL = os.getenv("API_URL")  # URL of the SeamlessM4T API endpoint

DEFAULT_TARGET_LANGUAGE = "Western Persian"

DESCRIPTION = """
# SeamlessM4T + Speaker Diarization + Voice Activity Detection
Here we use SeamlessM4T to generate captions for full audio recordings; the audio can be of arbitrary length.
"""

DUPLICATE = """
To duplicate this repo, you have to request access to three gated repositories and accept their user conditions:

1- https://huggingface.co/pyannote/voice-activity-detection

2- https://hf.co/pyannote/segmentation

3- https://hf.co/pyannote/speaker-diarization

"""
from pyannote.audio import Pipeline

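# The diarization pipeline is gated: HF_API must belong to an account that has
# accepted the user conditions of the three repositories listed above.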
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization", use_auth_token=HF_API
)


def predict(
    target_language, number_of_speakers, audio_source, input_audio_mic, input_audio_file
):
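    """Diarize the input audio, then transcribe each speaker turn with SeamlessM4T,
    streaming the accumulated transcript back to the UI."""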
    if audio_source == "microphone":
        input_data = input_audio_mic
    else:
        input_data = input_audio_file

    print(input_data)

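    # Pass num_speakers only when the user supplied one; 0 means auto-detect.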
    if number_of_speakers == 0:
        diarization = pipeline(input_data)
    else:
        diarization = pipeline(input_data, num_speakers=number_of_speakers)

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        print(f"start={turn.start}s stop={turn.end}s speaker_{speaker}")

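    # Load the whole recording with pydub so each speaker turn can be clipped out.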
    song = AudioSegment.from_file(input_data)

    client = Client(API_URL)
    output_text = ""
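    # Transcribe each diarized turn separately and stream the growing transcript.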
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        print(turn)
        try:
            # pydub slices by milliseconds
            clipped = song[int(turn.start * 1000) : int(turn.end * 1000)]
            clipped.export("my.wav", format="wav")

            _, result = client.predict(
                "ASR (Automatic Speech Recognition)",
                "file",  # str in 'Audio source' Radio component
                f"my.wav",
                f"my.wav",
                "text",
                target_language,
                target_language,
                api_name="/run",
            )
            current_text = f"start: {turn.start:.1f} end: {turn.end:.1f} text: {result} speaker: {speaker}"

            output_text = output_text + "\n" + current_text
            yield output_text

        except Exception as e:
            print(e)


def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
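    """Toggle visibility between the microphone and file-upload widgets."""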
    mic = audio_source == "microphone"
    return (
        gr.update(visible=mic, value=None),  # input_audio_mic
        gr.update(visible=not mic, value=None),  # input_audio_file
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            target_language = gr.Dropdown(
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                label="Output Language",
                value=DEFAULT_TARGET_LANGUAGE,
                interactive=True,
                info="Select your target language",
            )
            number_of_speakers = gr.Number(
                label="Number of Speakers",
                value=0,
                precision=0,
                info="Keep it at zero to let the model detect the number of speakers automatically",
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
                interactive=True,
            )
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
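            # Hidden mirror of whichever input is active; updated by the .change handlers below.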
            final_audio = gr.Audio(label="Output", visible=False)
        audio_source.change(
            fn=update_audio_ui,
            inputs=audio_source,
            outputs=[input_audio_mic, input_audio_file],
            queue=False,
            api_name=False,
        )
        input_audio_mic.change(lambda x: x, input_audio_mic, final_audio)
        input_audio_file.change(lambda x: x, input_audio_file, final_audio)
        submit = gr.Button("Submit")
        text_output = gr.Textbox(
            label="Transcribed Text",
            value="",
            interactive=False,
            lines=10,
            scale=10,
            max_lines=10,
        )

        submit.click(
            fn=predict,
            inputs=[
                target_language,
                number_of_speakers,
                audio_source,
                input_audio_mic,
                input_audio_file,
            ],
            outputs=[text_output],
            api_name="predict",
        )

    gr.Markdown(DUPLICATE)

demo.queue(max_size=50).launch()