# Hugging Face Space status banner: "Running on T4"
import logging

# Root logger at WARNING so debug/info messages from other libraries are suppressed.
logging.basicConfig(level=logging.WARNING)

# The "gradio_webrtc" logger itself accepts everything from DEBUG upwards.
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)

# Its records are written to a file; each line carries timestamp, logger name and level.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# NOTE(review): propagation is left enabled on this logger, so its records
# presumably also reach the root logger's handler — confirm that is intended.
import base64 | |
import io | |
import os | |
import tempfile | |
import time | |
import traceback | |
from dataclasses import dataclass | |
from threading import Event, Thread | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import requests | |
from gradio_webrtc import ReplyOnPause, WebRTC | |
from huggingface_hub import snapshot_download | |
from pydub import AudioSegment | |
from twilio.rest import Client | |
from server import serve | |
# from server import serve | |
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps | |
# Download the model snapshot into ./checkpoint at import time.
repo_id = "gpt-omni/mini-omni"
snapshot_download(repo_id, local_dir="./checkpoint", revision="main")

# Address of the chat endpoint that `speaking` posts to.
# NOTE(review): presumably this matches the address `serve` listens on — confirm
# against server.serve, which is started below without arguments.
IP = "0.0.0.0"
PORT = 60808
API_URL = f"http://{IP}:{PORT}/chat"

# Run the inference server on a daemon thread so it does not block the UI.
thread = Thread(target=serve, daemon=True)
thread.start()

# WebRTC configuration: use Twilio-issued ICE servers when both credentials are
# present in the environment, otherwise pass None to the WebRTC component.
rtc_configuration = None
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }

# Format used to interpret the raw audio bytes streamed back from the endpoint.
OUT_CHANNELS = 1  # channel count passed to AudioSegment
OUT_RATE = 24000  # frame rate passed to AudioSegment and reported to WebRTC
OUT_SAMPLE_WIDTH = 2  # bytes per sample
OUT_CHUNK = 20 * 4096  # chunk_size requested from the HTTP response stream
def _iter_whole_frames(chunks, frame_width):
    """Re-slice a byte stream so every yielded piece contains only complete frames."""
    carry = b""
    for chunk in chunks:
        if not chunk:
            continue
        data = carry + chunk
        usable = len(data) - len(data) % frame_width
        # Bytes of a frame split across two chunks wait here for the next chunk.
        carry = data[usable:]
        if usable:
            yield data[:usable]
    # Any bytes still in `carry` form an incomplete frame and are dropped.


def speaking(audio_bytes: bytes):
    """Post audio to the chat endpoint and stream the audio reply back.

    The request body is JSON with the input base64-encoded under the "audio"
    key. The response body is read in chunks of up to OUT_CHUNK bytes and
    interpreted as raw audio (OUT_RATE, OUT_SAMPLE_WIDTH, OUT_CHANNELS).
    After the stream ends, the complete reply is also written to a temporary
    .wav file whose path is printed.

    Args:
        audio_bytes: Encoded input audio to send to the endpoint.

    Yields:
        One numpy array per received piece of audio, reshaped to (1, n_samples).

    Raises:
        gr.Error: If the request fails, the server answers with an HTTP error
            status, or the audio cannot be decoded or written.
    """
    payload = {"audio": base64.b64encode(audio_bytes).decode("utf-8")}
    frame_width = OUT_SAMPLE_WIDTH * OUT_CHANNELS
    received = bytearray()  # mutable buffer: avoids quadratic bytes concatenation
    try:
        with requests.post(API_URL, json=payload, stream=True) as response:
            # Without this an error body would be decoded and played as audio.
            response.raise_for_status()
            stream = response.iter_content(chunk_size=OUT_CHUNK)
            # Chunks can end in the middle of a sample. Zero-padding such a chunk
            # would corrupt that sample and shift every following chunk by a
            # byte, so the stray byte is carried over into the next chunk.
            for pcm in _iter_whole_frames(stream, frame_width):
                received += pcm
                audio_segment = AudioSegment(
                    pcm,
                    frame_rate=OUT_RATE,
                    sample_width=OUT_SAMPLE_WIDTH,
                    channels=OUT_CHANNELS,
                )
                audio_np = np.array(audio_segment.get_array_of_samples())
                yield audio_np.reshape(1, -1)

            all_output_audio = AudioSegment(
                bytes(received),
                frame_rate=OUT_RATE,
                sample_width=OUT_SAMPLE_WIDTH,
                channels=OUT_CHANNELS,
            )
            # delete=False: the file is intentionally kept after the handle closes.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                all_output_audio.export(f.name, format="wav")
                print("output file written", f.name)
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}") from e
def response(audio: tuple[int, np.ndarray]):
    """Turn one captured utterance into a stream of reply audio frames.

    The incoming samples are squeezed, wrapped as single-channel audio at the
    given sampling rate, and exported to an in-memory WAV. That WAV is passed
    to `speaking`, and each array it produces is forwarded.

    Args:
        audio: Pair of (sampling rate, sample array) for the captured audio.

    Yields:
        Tuples of (OUT_RATE, array from `speaking`, "mono").
    """
    sampling_rate, samples = audio
    samples = samples.squeeze()

    wav_buffer = io.BytesIO()
    AudioSegment(
        samples.tobytes(),
        frame_rate=sampling_rate,
        # Bytes per sample are taken from the array's dtype.
        sample_width=samples.dtype.itemsize,
        channels=1,
    ).export(wav_buffer, format="wav")

    for reply_chunk in speaking(wav_buffer.getvalue()):
        yield (OUT_RATE, reply_chunk, "mono")
# Build the page: a centered title above a single send-receive audio stream.
with gr.Blocks() as demo:
    gr.HTML(
        """
<h1 style='text-align: center'>
Omni Chat (Powered by WebRTC ⚡️)
</h1>
"""
    )
    with gr.Column():
        with gr.Group():
            audio = WebRTC(
                label="Stream",
                rtc_configuration=rtc_configuration,
                mode="send-receive",
                modality="audio",
            )
        # The same component is both input and output; `response` is wrapped
        # in ReplyOnPause and each stream is limited with time_limit=60.
        reply_handler = ReplyOnPause(response)
        audio.stream(
            fn=reply_handler,
            inputs=[audio],
            outputs=[audio],
            time_limit=60,
        )

demo.launch(ssr_mode=False)