File size: 8,559 Bytes
bbef1a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a37dab
4e9b286
01a49c3
4e9b286
 
41d06ba
01a49c3
4e9b286
01a49c3
4e9b286
 
 
 
 
 
 
c4d6bf6
01a49c3
 
 
d531709
4e9b286
 
 
4a37dab
 
 
c4d6bf6
 
4a37dab
4e9b286
 
5e3f570
4a37dab
ee1db21
4a37dab
01a49c3
 
 
 
 
4e9b286
01a49c3
 
 
 
 
 
 
 
4e9b286
891e37e
 
 
 
 
 
 
 
 
 
 
c4d6bf6
4a37dab
c4d6bf6
891e37e
 
 
2084afa
891e37e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e9b286
 
891e37e
 
c4d6bf6
01a49c3
 
891e37e
5f58cac
 
 
 
 
4e9b286
01a49c3
63b59c5
4e9b286
 
5f58cac
 
556b4ae
2084afa
01a49c3
4e9b286
 
 
 
 
01a49c3
4e9b286
 
 
 
01a49c3
4e9b286
 
 
 
 
 
 
 
556b4ae
4e9b286
 
 
 
 
556b4ae
d531709
eb02780
d531709
 
4e9b286
d531709
 
 
 
c4d6bf6
4e9b286
c4d6bf6
4e9b286
c4d6bf6
 
 
 
4e9b286
 
 
 
 
 
 
 
 
 
 
 
d531709
 
 
 
4e9b286
 
 
 
 
 
 
2084afa
4e9b286
c4d6bf6
4e9b286
d531709
 
c4d6bf6
2084afa
63b59c5
4e9b286
01a49c3
eb02780
 
 
 
 
 
 
 
01a49c3
4e9b286
01a49c3
4e9b286
 
 
 
01a49c3
 
 
4e9b286
 
 
 
 
 
 
 
 
 
 
01a49c3
4e9b286
 
 
 
 
 
 
 
 
 
 
01a49c3
4e9b286
 
d531709
d9f5363
d531709
4e9b286
 
 
 
 
 
5f58cac
4e9b286
 
 
 
d5ade87
4e9b286
 
 
c3d833e
4e9b286
eb02780
31fe9de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import logging

# Configure the root logger to WARNING to suppress debug messages from other libraries.
logging.basicConfig(level=logging.WARNING)

# Console handler that lets every level through; the logger below does the filtering.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)

# Format records as "<logger name> - <LEVEL> - <message>".
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)

# Show DEBUG output from the gradio_webrtc library through our own handler.
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
# Propagation hands records straight to ancestor *handlers* without re-checking
# the ancestors' levels. Left on, every gradio_webrtc record would also reach
# the root handler installed by basicConfig and be printed a second time in a
# different format.
logger.propagate = False


import base64
import io
import os
import tempfile
import time
import traceback
from dataclasses import dataclass
from queue import Queue
from threading import Thread

import gradio as gr
import librosa
import numpy as np
import requests
from gradio_webrtc import StreamHandler, WebRTC
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from twilio.rest import Client

from server import serve

# from server import serve
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

# Download the mini-omni checkpoint (main revision) into ./checkpoint.
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

# NOTE(review): IP and PORT are not referenced anywhere else in this file.
IP = "0.0.0.0"
PORT = 60808

# Run server.serve in a background daemon thread so it does not block the UI
# or keep the process alive at exit.
# NOTE(review): speaking() posts to the remote API_URL below, not to this local
# server — confirm the local server is still needed.
thread = Thread(target=serve, daemon=True)
thread.start()


# Endpoint that receives each utterance and streams back the spoken reply.
API_URL = "https://freddyaboulton-omni-mini-webrtc-backend.hf.space/chat"

# WebRTC ICE configuration: with Twilio credentials in the environment, fetch
# TURN servers and force relayed traffic; otherwise leave it as None.
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

rtc_configuration = None
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }

# Recording (input) parameters.
IN_CHANNELS, IN_RATE, IN_CHUNK, IN_SAMPLE_WIDTH = 1, 24000, 1024, 2
VAD_STRIDE = 0.5

# Playback (output) parameters; OUT_CHUNK is the HTTP read size in bytes.
OUT_CHANNELS, OUT_RATE, OUT_SAMPLE_WIDTH = 1, 24000, 2
OUT_CHUNK = 20 * 4096


def run_vad(ori_audio, sr):
    """Run voice-activity detection on a chunk of audio.

    The audio is scaled to float, resampled to 16 kHz for the VAD, reduced to
    its speech segments, then resampled back to ``sr``.

    Args:
        ori_audio: numpy array of samples. The ``/ 32768.0`` scaling assumes
            int16 PCM — TODO confirm for every caller.
        sr: sampling rate of ``ori_audio`` in Hz.

    Returns:
        ``(duration_after_vad, vad_audio_bytes, time_cost)``:

        - ``duration_after_vad``: length in seconds of the audio kept by
          ``collect_chunks``.
        - ``vad_audio_bytes``: that audio as int16 bytes at ``sr``.
        - ``time_cost``: wall-clock seconds spent.

        On any failure the error is printed and ``(-1, ori_audio, time_cost)``
        is returned instead. Note that the second element is then the untouched
        input array, not bytes.
    """
    _st = time.time()
    try:
        sampling_rate = 16000  # rate the VAD operates at
        audio = ori_audio.astype(np.float32) / 32768.0
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        # Resampling can overshoot [-1, 1], and +1.0 * 32768 already exceeds
        # int16. Clip so the cast cannot wrap around to the opposite sign.
        vad_audio = np.clip(np.round(vad_audio * 32768.0), -32768, 32767).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        # ori_audio holds samples (not bytes), so its duration is len / sr.
        # Guard sr == 0 so reporting the error cannot itself raise.
        audio_len = len(ori_audio) / sr if sr else float("nan")
        msg = f"[asr vad error] audio_len: {audio_len:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    """Run the VAD once on 0.1 s of silence and print how long it took.

    Presumably intended to load the VAD model before the first real request.
    """
    # 1600 samples = 0.1 s of mono silence at 16 kHz. The array is 1-D, like
    # the squeezed buffers process_audio hands to determine_pause, and int16,
    # matching run_vad's / 32768 scaling. The previous (1, 1600) float64 array
    # did not match that layout.
    frames = np.zeros(1600, dtype=np.int16)
    _, frames, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


# warm_up()


@dataclass
class AppState:
    """Per-session conversation state shared by the audio-processing helpers.

    Field order defines the generated ``__init__``; append new fields at the end.
    """

    # Utterance audio accumulated since the user started talking (None until then).
    stream: np.ndarray | None = None
    # Sampling rate of incoming frames; 0 until process_audio sets it from the first frame.
    sampling_rate: int = 0
    # Result of the most recent pause check made by process_audio.
    pause_detected: bool = False
    # Set by determine_pause once a VAD window contains more than 0.2 s of speech.
    started_talking: bool = False
    # Set by OmniHandler.emit while a reply plays; OmniHandler.receive drops frames meanwhile.
    responding: bool = False
    # NOTE(review): not read or written anywhere in the visible code.
    stopped: bool = False
    # Frames received but not yet analysed; determine_pause clears it after each >= 0.6 s window.
    buffer: np.ndarray | None = None


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Decide whether the user has stopped talking.

    ``audio`` is the input buffered since the last decision. Nothing is decided
    until it holds at least 0.6 s. Then the window goes through the VAD:

    - more than 0.2 s of speech marks the start of talking;
    - once talking has started, the window is appended to ``state.stream``;
    - ``state.buffer`` is cleared;
    - a window with less than 0.1 s of speech after talking has started counts
      as a pause.

    A VAD failure (``run_vad`` returns ``-1``) is never counted as a pause.

    Side effects: mutates ``state.started_talking``, ``state.stream`` and
    ``state.buffer``. On a pause it also writes the accumulated utterance to a
    temporary WAV file, which is never deleted, and prints its path.

    Returns:
        True if a pause was detected in this window, otherwise False.
    """
    duration = len(audio) / sampling_rate
    if duration < 0.60:
        # Too little audio buffered to decide; skip the (costly) VAD entirely
        # rather than computing a result that would be discarded.
        return False

    dur_vad, _, _ = run_vad(audio, sampling_rate)
    # run_vad reports failure as -1, which would otherwise pass the
    # "< 0.1 s of speech" test below and end the user's turn spuriously.
    vad_failed = dur_vad < 0

    if dur_vad > 0.2 and not state.started_talking:
        print("started talking")
        state.started_talking = True
    if state.started_talking:
        if state.stream is None:
            state.stream = audio
        else:
            state.stream = np.concatenate((state.stream, audio))
    state.buffer = None

    if vad_failed or not state.started_talking or dur_vad >= 0.1:
        return False

    # Pause detected: dump the whole utterance to a WAV file for debugging.
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        segment.export(f.name, format="wav")
    print("input file written", f.name)
    return True


def speaking(audio_bytes: bytes):
    """Send an utterance to the backend and stream back its spoken reply.

    Args:
        audio_bytes: the utterance as an encoded audio file (``response``
            passes WAV bytes). It is base64-encoded into the JSON request body.

    Yields:
        numpy arrays of shape ``(1, n)`` holding reply samples as they arrive.
        The response body is interpreted as raw PCM with ``OUT_RATE``,
        ``OUT_SAMPLE_WIDTH`` and ``OUT_CHANNELS``.

    Network chunks need not end on a sample boundary, so any partial frame is
    carried into the next chunk. A trailing partial frame at end of stream is
    dropped. After the stream ends, the whole reply is also written to a
    temporary WAV file, which is never deleted, for debugging.

    Raises:
        gr.Error: on any failure while requesting or streaming the reply,
            including a non-2xx HTTP status.
    """
    base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
    files = {"audio": base64_encoded}
    frame_size = OUT_SAMPLE_WIDTH * OUT_CHANNELS
    received = []  # frame-aligned chunks; joined once at the end
    leftover = b""  # partial frame carried over from the previous chunk
    try:
        with requests.post(API_URL, json=files, stream=True) as response:
            # Without this an HTML/JSON error body would be played as audio.
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if not chunk:
                    continue
                data = leftover + chunk
                cut = len(data) - len(data) % frame_size
                aligned, leftover = data[:cut], data[cut:]
                if not aligned:
                    continue
                received.append(aligned)
                audio_segment = AudioSegment(
                    aligned,
                    frame_rate=OUT_RATE,
                    sample_width=OUT_SAMPLE_WIDTH,
                    channels=OUT_CHANNELS,
                )
                # Interleaved samples as a single row.
                audio_np = np.array(audio_segment.get_array_of_samples())
                yield audio_np.reshape(1, -1)
        all_output_audio = AudioSegment(
            b"".join(received),
            frame_rate=OUT_RATE,
            sample_width=OUT_SAMPLE_WIDTH,
            channels=OUT_CHANNELS,
        )
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            all_output_audio.export(f.name, format="wav")
            print("output file written", f.name)
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}") from e


def process_audio(audio: tuple, state: AppState) -> None:
    """Buffer one incoming frame and refresh ``state.pause_detected``.

    Args:
        audio: ``(frame_rate, samples)`` pair for one incoming frame; the
            samples are squeezed to drop singleton dimensions.
        state: per-session state. Its buffer grows by this frame, and its
            sampling rate is fixed by the first frame seen.
    """
    rate, samples = audio
    samples = np.squeeze(samples)
    # Only the first frame decides the session's sampling rate.
    state.sampling_rate = state.sampling_rate or rate
    state.buffer = (
        samples if state.buffer is None else np.concatenate((state.buffer, samples))
    )
    state.pause_detected = determine_pause(state.buffer, state.sampling_rate, state)


def response(state: AppState):
    """Yield the spoken reply to the utterance collected in ``state``.

    A generator of ``(OUT_RATE, samples, "mono")`` tuples, one per chunk that
    ``speaking`` yields. It produces nothing when the user has neither paused
    nor started talking.
    """
    if state.pause_detected or state.started_talking:
        utterance = state.stream
        wav_file = io.BytesIO()
        # Package the collected samples as an in-memory WAV for the backend.
        AudioSegment(
            utterance.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=utterance.dtype.itemsize,
            channels=1 if utterance.ndim == 1 else utterance.shape[1],
        ).export(wav_file, format="wav")

        for samples in speaking(wav_file.getvalue()):
            yield OUT_RATE, samples, "mono"


class OmniHandler(StreamHandler):
    """WebRTC stream handler that answers the user after each spoken turn.

    ``receive`` accumulates incoming frames until ``process_audio`` flags a
    pause. ``emit`` waits for that signal and then returns the reply chunk by
    chunk, resetting the session state once the reply is exhausted.
    """

    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=OUT_RATE,
            output_frame_size=480,
        )
        self.chunk_queue = Queue()  # one item per detected pause; consumed by emit()
        self.state = AppState()
        self.generator = None  # active reply generator, or None between turns
        self.duration = 0

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Feed one incoming frame into pause detection; ignored while replying."""
        if not self.state.responding:
            process_audio(frame, self.state)
            if self.state.pause_detected:
                # Signal emit(), which waits on this queue before replying.
                self.chunk_queue.put(True)

    def reset(self):
        """Drop the finished reply and start the next turn with fresh state."""
        self.generator = None
        self.state = AppState()
        self.duration = 0

    def emit(self):
        """Return the next reply chunk, blocking until the user has paused.

        Once the reply generator is exhausted, resets and returns None.
        """
        reply = self.generator
        if not reply:
            self.chunk_queue.get()  # blocks until receive() reports a pause
            # Flag first so receive() stops mutating state during the reply.
            self.state.responding = True
            reply = self.generator = response(self.state)
        try:
            return next(reply)
        except StopIteration:
            self.reset()


with gr.Blocks() as demo:
    # Centered page title.
    gr.HTML(
        """
    <h1 style='text-align: center'>
    Omni Chat (Powered by WebRTC ⚡️)
    </h1>
    """
    )
    with gr.Column():
        with gr.Group():
            # One send-receive audio component: it carries the user's audio in
            # and the reply back out.
            audio = WebRTC(
                modality="audio",
                mode="send-receive",
                rtc_configuration=rtc_configuration,
                label="Stream",
            )
        # Route the component's stream through a single OmniHandler, with the
        # component as both input and output; time_limit=60 (presumably
        # seconds) caps each stream.
        audio.stream(
            fn=OmniHandler(),
            inputs=[audio],
            outputs=[audio],
            time_limit=60,
        )


# Launched unconditionally at module level (no __main__ guard).
demo.launch(ssr_mode=False)