"""Gradio voice-chat demo for the mini-omni model.

On import this script downloads the ``gpt-omni/mini-omni`` snapshot into
``./checkpoint``, starts ``server.serve`` in a daemon thread, warms up the VAD
and launches a Gradio UI.  Recorded audio is posted to ``API_URL`` as a
base64-encoded WAV and the streamed reply is re-encoded to MP3 chunks for the
streaming output player.
"""

import base64
import io
import tempfile  # no longer used below; file-level import kept on purpose
import time
import traceback
import wave
from dataclasses import dataclass
from pathlib import Path
from threading import Thread

import gradio as gr
import librosa
import numpy as np
import requests
from huggingface_hub import snapshot_download
from pydub import AudioSegment

from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
from server import serve

repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

IP = "0.0.0.0"
PORT = 60808

# The inference server runs in-process on a daemon thread so it dies with the UI.
thread = Thread(target=serve, daemon=True)
thread.start()

# NOTE(review): `serve` is started without arguments, so IP/PORT here presumably
# have to match its defaults -- confirm against server.py before changing them.
API_URL = f"http://{IP}:{PORT}/chat"

# recording parameters
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5

# playing parameters
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096


def run_vad(ori_audio, sr):
    """Strip non-speech from raw int16 PCM bytes using the project VAD.

    Args:
        ori_audio: Raw PCM bytes, interpreted as int16 samples.
        sr: Sampling rate of ``ori_audio`` in Hz.

    Returns:
        A tuple ``(duration_after_vad, vad_audio_bytes, elapsed)`` where
        ``duration_after_vad`` is the amount of retained speech in seconds,
        ``vad_audio_bytes`` is the retained audio as int16 PCM bytes at ``sr``
        and ``elapsed`` is the wall-clock time spent, rounded to 4 decimals.
        On any failure the error is printed and
        ``(-1, ori_audio, elapsed)`` is returned instead.
    """
    _st = time.time()
    try:
        audio = np.frombuffer(ori_audio, dtype=np.int16)
        audio = audio.astype(np.float32) / 32768.0

        # The VAD is run at 16 kHz; resample when the input rate differs.
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio

        # Clip before the int16 cast: a float sample >= 1.0 (possible after
        # resampling) would otherwise wrap around to a large negative value.
        vad_audio = np.clip(np.round(vad_audio * 32768.0), -32768, 32767)
        vad_audio_bytes = vad_audio.astype(np.int16).tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        # Best-effort: report the failure and hand the input back untouched.
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)


def warm_up():
    """Run the VAD once on silence so the first real request is not slowed down."""
    frames = b"\x00\x00" * 1024 * 2  # 2048 zero samples of 2 bytes each
    _, _, tcost = run_vad(frames, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")


warm_up()


def determine_pause(stream: bytes, start_talking: bool) -> tuple[bool, bool]:
    """Take in the stream, determine if a pause happened.

    Args:
        stream: Accumulated recording as raw PCM bytes.
        start_talking: Whether speech has already been detected.

    Returns:
        ``(pause, start_talking)``: ``pause`` is True when speech had started
        and the VAD now finds less than 0.1 s of it; ``start_talking`` is the
        possibly updated speech-started flag.
    """
    # Skip the VAD until more than VAD_STRIDE seconds of audio are buffered.
    min_bytes = IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE
    if len(stream) <= min_bytes:
        return False, start_talking

    dur_vad, _, time_vad = run_vad(stream, IN_RATE)
    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

    if dur_vad > 0.2 and not start_talking:
        return False, True
    if dur_vad < 0.1 and start_talking:
        print("pause detected")
        return True, start_talking
    return False, start_talking


def speaking(total_frames: bytes):
    """Send the recording to the chat server and stream the reply as MP3.

    Args:
        total_frames: Recorded audio as raw PCM bytes; wrapped in a WAV
            container using the ``IN_*`` recording parameters.

    Yields:
        MP3-encoded bytes, one item per chunk received from the server.

    Raises:
        gr.Error: If the request fails or a chunk cannot be converted.
    """
    audio_buffer = io.BytesIO()
    # The context manager finalises the WAV header and releases the writer even
    # if this generator is abandoned; the BytesIO itself stays readable.
    with wave.open(audio_buffer, "wb") as wf:
        wf.setnchannels(IN_CHANNELS)
        wf.setsampwidth(IN_SAMPLE_WIDTH)
        wf.setframerate(IN_RATE)

        dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
        print(f"Speaking... recorded audio duration: {dur:.3f} s")

        wf.writeframes(total_frames)

    # The previous version also dumped this WAV into a NamedTemporaryFile with
    # delete=False that was never read or removed; that leak is gone.
    audio_bytes = audio_buffer.getvalue()
    base64_encoded = base64.b64encode(audio_bytes).decode("utf-8")
    payload = {"audio": base64_encoded}

    with requests.post(API_URL, json=payload, stream=True) as response:
        try:
            # Surface HTTP errors instead of decoding an error body as audio.
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                if not chunk:
                    continue
                # Wrap the raw PCM chunk so pydub can re-encode it.
                audio_segment = AudioSegment(
                    chunk,
                    frame_rate=OUT_RATE,
                    sample_width=OUT_SAMPLE_WIDTH,
                    channels=OUT_CHANNELS,
                )
                # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
                with io.BytesIO() as mp3_io:
                    audio_segment.export(mp3_io, format="mp3", bitrate="320k")
                    mp3_bytes = mp3_io.getvalue()
                yield mp3_bytes
        except Exception as e:
            raise gr.Error(f"Error during audio streaming: {e}") from e


@dataclass
class AppState:
    """Per-session conversation state kept in ``gr.State``."""

    # True once the VAD has found speech in the buffered stream.
    start_talking: bool = False
    # Bytes accumulated from the recorded audio files.
    stream: bytes = b""
    # True when determine_pause reported a pause on the last call.
    pause_detected: bool = False


def process_audio(audio: str, state: AppState):
    """Handle a finished recording: update the state and stream the reply.

    Args:
        audio: Path of the recorded audio file (the input component uses
            ``type="filepath"``).
        state: Session state; mutated in place and replaced by a fresh
            ``AppState`` once the reply has been streamed.

    Yields:
        ``(output_audio, state)`` pairs: ``None`` or MP3 bytes for the
        streaming output player, plus the state to store.
    """
    # NOTE(review): read_bytes() appends the file's raw bytes, container header
    # included, while run_vad/speaking treat the stream as headerless int16 PCM
    # at IN_RATE -- confirm the recorded format matches.
    state.stream += Path(audio).read_bytes()

    # Pass the speech-started flag; the previous code passed pause_detected here.
    pause_detected, start_talking = determine_pause(state.stream, state.start_talking)
    state.pause_detected = pause_detected
    state.start_talking = start_talking

    if not state.pause_detected:
        # NOTE(review): intentionally no early return. This handler is bound to
        # stop_recording, which fires once per recording, so the reply is
        # produced whether or not the VAD flagged a pause. Add a `return` here
        # only if the handler is moved to a continuous `.stream` event.
        yield None, state

    for out_bytes in speaking(state.stream):
        yield out_bytes, state

    # Start the next turn from a clean state.
    state = AppState()
    yield None, state


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="filepath"
            )
        with gr.Column():
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
    state = gr.State(value=AppState())

    input_audio.stop_recording(
        process_audio,
        [input_audio, state],
        [output_audio, state],
        stream_every=0.5,
        time_limit=30,
    )

demo.launch()