# freddyaboulton's picture
# Update app.py
# bf0cc6a verified
import logging
import base64
import io
import os
from threading import Thread
import gradio as gr
import numpy as np
import requests
from gradio_webrtc import ReplyOnPause, WebRTC, AdditionalOutputs
from pydub import AudioSegment
from twilio.rest import Client
from server import serve
# Root logger is configured at WARNING.
logging.basicConfig(level=logging.WARNING)

# The "gradio_webrtc" logger is opened up to DEBUG and given its own
# file handler writing to gradio_webrtc.log.
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Host and port used to reach the chat endpoint started by `serve`.
# NOTE(review): `serve` is called without arguments, so it presumably binds to
# this same address on its own -- confirm against server.py.
IP = "0.0.0.0"
PORT = 60808

# Run the server in a daemon thread alongside the Gradio app.
thread = Thread(target=serve, daemon=True)
thread.start()

# Built from IP/PORT so the client URL cannot drift from the constants above.
API_URL = f"http://{IP}:{PORT}/chat"
# Only needed if deploying on cloud provider
# When both Twilio credentials are present in the environment, request a
# token and hand its ICE servers to the WebRTC component; otherwise no
# explicit RTC configuration is used.
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

rtc_configuration = None
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
# Output audio parameters.
# OUT_RATE is the sample rate attached to every audio chunk yielded by
# `response` and passed to ReplyOnPause as `output_sample_rate`.
OUT_RATE = 24000
# NOTE(review): the three values below are not referenced anywhere else in the
# visible file; presumably channel count, bytes per sample and a chunk size.
OUT_CHANNELS = 1
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 4096 * 20
def response(audio: tuple[int, np.ndarray], conversation: list[dict], img: str | None):
    """Send the user's audio (and optional image) to the chat server and stream the reply.

    The recorded audio is wrapped as a WAV file, base64-encoded and POSTed as
    JSON to ``API_URL``. The streamed response body is split on the
    ``\\r\\n--frame\\r\\n`` boundary; each frame is dispatched on the
    ``Content-Type`` marker it contains.

    Args:
        audio: ``(sampling_rate, samples)`` pair for the recorded utterance.
        conversation: Chat history as a list of message dicts. It is mutated
            in place: a user message and an (initially empty) assistant
            message are appended, and the assistant text is updated as text
            frames arrive.
        img: Path to an image file to send along with the audio, or ``None``.

    Yields:
        ``(OUT_RATE, samples, "mono")`` for every audio frame, where
        ``samples`` is an int16 array of shape ``(1, n)``, and
        ``AdditionalOutputs(conversation)`` for every text frame.

    Raises:
        Exception: If the server answers with an HTTP error status or anything
            fails while reading or parsing the stream; the original error is
            chained as the cause.
    """
    sampling_rate, audio_np = audio
    audio_np = audio_np.squeeze()

    # Wrap the raw samples as a single-channel WAV file held in memory.
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    segment.export(audio_buffer, format="wav")

    conversation.append({"role": "user", "content": gr.Audio((sampling_rate, audio_np))})
    # Placeholder that is filled in incrementally from the text frames below.
    conversation.append({"role": "assistant", "content": ""})

    if API_URL is None:
        return

    payload = {"audio": base64.b64encode(audio_buffer.getvalue()).decode("utf-8")}
    if img is not None:
        # Context manager so the image file handle is always closed.
        with open(img, "rb") as image_file:
            payload["image"] = base64.b64encode(image_file.read()).decode("utf-8")

    print("sending request to server")
    boundary = b"\r\n--frame\r\n"
    header_end = b"\r\n\r\n"
    resp_text = ""
    with requests.post(API_URL, json=payload, stream=True) as http_response:
        try:
            # Surface server-side failures instead of silently parsing an
            # error page as a stream that contains no frames.
            http_response.raise_for_status()
            buffer = b""
            for chunk in http_response.iter_content(chunk_size=2048):
                buffer += chunk
                # A frame is only handled once its trailing boundary has
                # arrived, so frames split across network chunks stay intact.
                # NOTE(review): bytes after the last boundary are discarded --
                # confirm the server terminates every frame with a boundary.
                while boundary in buffer:
                    frame, buffer = buffer.split(boundary, 1)
                    if b"Content-Type: audio/wav" in frame:
                        # The payload after the blank line is read directly
                        # as int16 samples (no base64 decoding).
                        audio_data = frame.split(header_end, 1)[1]
                        audio_array = np.frombuffer(audio_data, dtype=np.int16).reshape(1, -1)
                        yield (OUT_RATE, audio_array, "mono")
                    elif b"Content-Type: text/plain" in frame:
                        resp_text += frame.split(header_end, 1)[1].decode()
                        conversation[-1]["content"] = resp_text
                        yield AdditionalOutputs(conversation)
        except Exception as e:
            raise Exception(f"Error during audio streaming: {e}") from e
# NOTE(review): the layout nesting below was reconstructed from source whose
# indentation had been flattened; verify the Row/Column hierarchy against the
# intended UI (stream + image side by side on the left, chat on the right).
with gr.Blocks() as demo:
    gr.HTML(
        "\n<h1 style='text-align: center'>\n"
        "Mini-Omni-2 Chat (Powered by WebRTC ⚡️)\n"
        "</h1>\n"
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    audio = WebRTC(
                        label="Stream",
                        rtc_configuration=rtc_configuration,
                        mode="send-receive",
                        modality="audio",
                    )
                with gr.Column():
                    # type="filepath" means `response` receives a path string.
                    img = gr.Image(label="Image", type="filepath")
        with gr.Column():
            conversation = gr.Chatbot(label="Conversation", type="messages")

    # `response` is wrapped in ReplyOnPause and streams audio back to the
    # same WebRTC component.
    audio.stream(
        fn=ReplyOnPause(response, output_sample_rate=OUT_RATE, output_frame_size=480),
        inputs=[audio, conversation, img],
        outputs=[audio],
        time_limit=90,
    )
    # AdditionalOutputs yielded by `response` carry the updated chat history;
    # pass it straight through to the Chatbot.
    audio.on_additional_outputs(lambda history: history, outputs=[conversation])

demo.launch()