import gradio as gr
from huggingface_hub import snapshot_download
from threading import Thread
import time
import base64
import numpy as np
import requests
import traceback
from dataclasses import dataclass
from pathlib import Path
import io
import wave
import tempfile
from pydub import AudioSegment
import librosa
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
from server import serve

# Download the model snapshot into ./checkpoint before starting the server
# (presumably serve() loads its weights from there -- TODO confirm in server.py).
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

# Run the inference server in a background daemon thread so it exits with the UI.
thread = Thread(target=serve, daemon=True)
thread.start()

# Same URL as before ("http://0.0.0.0:60808/chat"), derived from IP/PORT so the
# values cannot drift apart.
API_URL = f"http://{IP}:{PORT}/chat"

# recording parameters
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5

# playing parameters: format of the raw PCM streamed back by the server
# (assumed 16-bit mono at 24 kHz -- TODO confirm against server.py).
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
# Bytes requested per read of the streaming HTTP response. (The file previously
# assigned 5760 and immediately overrode it; only the effective value is kept.)
OUT_CHUNK = 20 * 4096


def run_vad(ori_audio, sr):
    """Strip non-speech from ``ori_audio`` using the helpers in ``utils.vad``.

    Args:
        ori_audio: 16-bit PCM samples, either as an ``np.int16`` array (what the
            Gradio microphone stream delivers) or as raw little-endian ``bytes``.
        sr: Sampling rate of ``ori_audio`` in Hz.

    Returns:
        ``(duration_after_vad, vad_audio_bytes, time_cost)``:
        ``duration_after_vad`` is the number of seconds of speech kept,
        ``vad_audio_bytes`` is the kept audio re-encoded as int16 PCM bytes at
        ``sr``, and ``time_cost`` is the wall-clock seconds spent (rounded to
        4 places). On any error the traceback is printed and
        ``(-1, ori_audio, time_cost)`` is returned instead.
    """
    _st = time.time()
    try:
        if isinstance(ori_audio, (bytes, bytearray)):
            # Raw int16 PCM (e.g. the silent buffer used by warm_up).
            audio = np.frombuffer(ori_audio, dtype=np.int16)
        else:
            audio = ori_audio
        audio = audio.astype(np.float32) / 32768.0

        # The VAD runs at 16 kHz; resample in and back out when needed.
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        speech_chunks = get_speech_timestamps(audio, VadOptions())
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        # Skip resampling an empty (all-silence) result.
        if sr != sampling_rate and audio.size:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        # Clip before casting: +1.0 * 32768 (or resampling overshoot) would
        # otherwise wrap around in int16.
        vad_audio = np.clip(np.round(vad_audio * 32768.0), -32768, 32767).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        # Report the input length in seconds: bytes hold 2 bytes per int16
        # sample, arrays are counted in samples.
        if isinstance(ori_audio, (bytes, bytearray)):
            n_samples = len(ori_audio) / 2
        else:
            n_samples = len(ori_audio)
        msg = f"[asr vad error] audio_len: {n_samples / sr:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    """Run the VAD once on silence at import time.

    Presumably done to pay any one-time VAD initialization cost before the
    first user request.
    """
    frames = b"\x00\x00" * 1024 * 2  # 2048 int16 samples of silence (4096 bytes)
    _dur, _vad_bytes, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()


def determine_pause(audio: np.ndarray, sampling_rate: int) -> bool:
    """Take in the stream, determine if a pause happened.

    A pause is reported when more than 0.5 s of the accumulated ``audio`` was
    removed by the VAD (i.e. classified as non-speech).
    """
    dur_vad, _, time_vad = run_vad(audio, sampling_rate)
    duration = len(audio) / sampling_rate

    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

    # NOTE(review): on VAD failure run_vad returns -1, which makes this test
    # true for any input -- a VAD error is treated as a pause. Confirm intended.
    return (duration - dur_vad) > 0.5


def _pcm_to_mp3(pcm: bytes) -> bytes:
    """Encode server PCM (OUT_* format) as a 320 kbps MP3; high bitrate keeps quality."""
    segment = AudioSegment(
        pcm,
        frame_rate=OUT_RATE,
        sample_width=OUT_SAMPLE_WIDTH,
        channels=OUT_CHANNELS,
    )
    with io.BytesIO() as mp3_io:
        segment.export(mp3_io, format="mp3", bitrate="320k")
        return mp3_io.getvalue()


def speaking(audio: np.ndarray, sampling_rate: int):
    """Send recorded audio to the chat server and stream its reply as MP3.

    Args:
        audio: Recorded PCM samples, shape ``(n,)`` for mono or
            ``(n, channels)``; the sample width is taken from the dtype.
        sampling_rate: Sampling rate of ``audio`` in Hz.

    Yields:
        MP3-encoded bytes for each frame-aligned piece of raw PCM received
        from the server.

    Raises:
        gr.Error: if the HTTP request fails or a chunk cannot be encoded.

    Side effect: the request audio is also written to ``input_audio.wav`` in
    the current working directory.
    """
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio.dtype.itemsize,
        channels=(1 if audio.ndim == 1 else audio.shape[1]),
    )
    segment.export(audio_buffer, format="wav")
    audio_bytes = audio_buffer.getvalue()

    with open("input_audio.wav", "wb") as f:
        f.write(audio_bytes)

    base64_encoded = base64.b64encode(audio_bytes).decode("utf-8")
    files = {"audio": base64_encoded}

    # Streamed chunks need not end on a sample boundary, and AudioSegment
    # rejects data whose length is not a multiple of the frame size, so carry
    # any trailing partial frame over to the next chunk.
    frame_bytes = OUT_SAMPLE_WIDTH * OUT_CHANNELS
    pending = b""
    try:
        with requests.post(API_URL, json=files, stream=True) as resp:
            resp.raise_for_status()
            for chunk in resp.iter_content(chunk_size=OUT_CHUNK):
                if not chunk:
                    continue
                pending += chunk
                usable = len(pending) - len(pending) % frame_bytes
                if not usable:
                    continue
                pcm, pending = pending[:usable], pending[usable:]
                yield _pcm_to_mp3(pcm)
        # Any bytes still in `pending` are an incomplete final sample; dropped.
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}") from e


@dataclass
class AppState:
    # Accumulated microphone samples for the current turn (None until first chunk).
    stream: np.ndarray | None = None
    # Sampling rate reported with the first chunk of the turn.
    sampling_rate: int = 0
    # Set by process_audio once determine_pause reports a pause.
    pause_detected: bool = False


def process_audio(audio: tuple, state: AppState):
    """Append a streamed ``(sampling_rate, samples)`` chunk and check for a pause.

    Returns ``(input_audio update, state)``; the update stops recording once a
    pause is detected, which in turn fires the ``stop_recording`` handler.
    """
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    state.pause_detected = determine_pause(state.stream, state.sampling_rate)

    if state.pause_detected:
        return gr.Audio(recording=False), state
    return None, state
def response(state: AppState):
    """Stream the spoken reply for the recorded turn, then re-arm the microphone.

    Yields ``(input_audio update, output_audio chunk, state)`` triples.
    """
    if not state.pause_detected:
        # This is a generator, so `return value` would be discarded (it only
        # becomes StopIteration.value); yield the reset explicitly instead.
        yield None, None, AppState()
        return

    for mp3_bytes in speaking(state.stream, state.sampling_rate):
        yield None, mp3_bytes, state

    # Reply finished: restart recording with a fresh state for the next turn.
    yield gr.Audio(recording=True), None, AppState()


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="numpy" delivers (sampling_rate, samples) tuples, which is
            # what process_audio indexes into.
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="numpy"
            )
        with gr.Column():
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    state = gr.State(value=AppState())

    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.5,
        time_limit=30,
    )
    # Fires when recording stops, including when process_audio stops it on a pause.
    respond = input_audio.stop_recording(
        response, [state], [input_audio, output_audio, state]
    )

    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(lambda: AppState(), None, [state], cancels=[respond])

demo.launch()