|
import gradio as gr
from huggingface_hub import snapshot_download
from threading import Thread
import time
import base64
import numpy as np
import requests
import traceback
from dataclasses import dataclass
import io
from pydub import AudioSegment
import librosa
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

from server import serve
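
# Download the model weights, then run the mini-omni server in a daemon
# thread so it starts alongside the UI. `serve` (from server.py) is assumed
# to block and to expose POST /chat on IP:PORT.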
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

thread = Thread(target=serve, daemon=True)
thread.start()

API_URL = f"http://{IP}:{PORT}/chat"
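
# Audio constants: mono 16-bit PCM at 24 kHz on both the input and output
# sides; OUT_CHUNK is the number of bytes read from the HTTP response per
# iteration.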
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5

OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096
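

# run_vad trims non-speech from a clip using the helpers in utils.vad and
# returns (speech_duration_s, trimmed_pcm_bytes, elapsed_s); on failure it
# returns a duration of -1 and the input unchanged.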
def run_vad(ori_audio, sr):
    _st = time.time()
    try:
        # Accept raw 16-bit PCM bytes (as sent by warm_up) or an int16 array
        # (as sent by determine_pause).
        if isinstance(ori_audio, (bytes, bytearray)):
            audio = np.frombuffer(ori_audio, dtype=np.int16)
        else:
            audio = ori_audio
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # Resample back so callers get audio at the rate they passed in.
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        n_samples = len(ori_audio) // 2 if isinstance(ori_audio, (bytes, bytearray)) else len(ori_audio)
        msg = f"[asr vad error] audio_len: {n_samples / sr:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    frames = b"\x00\x00" * 1024 * 2  # 2048 samples of int16 silence
    _, _, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()
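

# Per-session conversation state, passed through every Gradio event so
# concurrent users do not share buffers.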
@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False  # annotated so dataclass treats it as a field


def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the accumulated stream and determine whether a pause happened."""
    dur_vad, _, time_vad = run_vad(audio, sampling_rate)
    duration = len(audio) / sampling_rate

    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False

    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

    # A pause is more than one second of accumulated non-speech audio.
    return (duration - dur_vad) > 1
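

# speaking() performs one round trip: encode the recorded utterance as WAV,
# POST it to /chat as base64 JSON, then re-encode each streamed chunk as MP3
# for the browser. The raw-PCM response format is an assumption about the
# server; adjust the OUT_* constants above if it differs.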
def speaking(audio: np.ndarray, sampling_rate: int):
    audio_buffer = io.BytesIO()

    segment = AudioSegment(
        audio.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio.dtype.itemsize,
        channels=(1 if len(audio.shape) == 1 else audio.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    # Keep a copy of the last utterance on disk for debugging.
    with open("input_audio.wav", "wb") as f:
        f.write(audio_buffer.getvalue())

    audio_bytes = audio_buffer.getvalue()

    base64_encoded = base64.b64encode(audio_bytes).decode("utf-8")
    payload = {"audio": base64_encoded}
    with requests.post(API_URL, json=payload, stream=True) as response:
        try:
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if chunk:
                    # Interpret the chunk as raw PCM and transcode it to MP3.
                    audio_segment = AudioSegment(
                        chunk,
                        frame_rate=OUT_RATE,
                        sample_width=OUT_SAMPLE_WIDTH,
                        channels=OUT_CHANNELS,
                    )

                    mp3_io = io.BytesIO()
                    audio_segment.export(mp3_io, format="mp3", bitrate="320k")

                    mp3_bytes = mp3_io.getvalue()
                    mp3_io.close()
                    yield mp3_bytes
        except Exception as e:
            raise gr.Error(f"Error during audio streaming: {e}")
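

# Called every stream_every seconds while recording: accumulates the incoming
# (sample_rate, samples) chunks and stops the recorder once a pause is heard.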
def process_audio(audio: tuple, state: AppState):
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected

    if state.pause_detected and state.started_talking:
        return gr.Audio(recording=False), state
    return None, state


def response(state: AppState):
    if not state.pause_detected:
        # Returning from a generator ends the stream without emitting audio.
        return None, AppState()

    for mp3_bytes in speaking(state.stream, state.sampling_rate):
        yield mp3_bytes, state

    # Reset the state so the next utterance starts a fresh turn.
    yield None, AppState()


def start_recording_user(state: AppState):
    if state.pause_detected and state.started_talking:
        return gr.Audio(recording=True)
    return None
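

# UI wiring: the microphone streams into process_audio; when recording stops,
# response streams MP3 back; when playback stops, recording restarts. The
# Stop button resets the state and cancels both streaming events.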
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="numpy"
            )
        with gr.Column():
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    state = gr.State(value=AppState())

    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.5,
        time_limit=30,
    )
    respond = input_audio.stop_recording(
        response,
        [state],
        [output_audio, state],
    )
    restart = output_audio.stop(
        start_recording_user,
        [state],
        [input_audio],
    )
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(lambda: AppState(), None, [state], cancels=[respond, restart])


demo.launch()