# S3TVR-Demo / main.py
# Provenance: Hugging Face Space file (author: yalsaffar, commit aa7cb02, "init"), 2.77 kB.
import streamlit as st
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
import av
import numpy as np
import pydub
from io import BytesIO
from models.nllb import nllb
from models.parakeet import parakeet_ctc_model
from stream_VAD import stream
from models.es_fastconformer import stt_es_model
# ICE configuration: a public Google STUN server so the WebRTC peer
# connection can discover its public address and traverse NAT.
RTC_CONFIGURATION = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
# Load models once
# NOTE(review): Streamlit re-executes this whole module on every UI
# interaction, so these model loads repeat on each rerun unless wrapped
# in a caching decorator (e.g. st.cache_resource) — confirm intent.
model_nllb, tokenizer_nllb = nllb()          # NLLB translation model + tokenizer
parakeet = parakeet_ctc_model()              # English speech-to-text (Parakeet CTC)
stt_model = stt_es_model()                   # Spanish speech-to-text (FastConformer)
def process_audio(audio_chunk, language):
    """Run one audio frame through the STT + translation pipeline.

    Args:
        audio_chunk: an ``av.AudioFrame`` holding raw PCM samples.
        language: ``"en"`` (English -> Spanish) or ``"es"``
            (Spanish -> English). Anything else is passed through.

    Returns:
        Tuple ``(frame_rate, samples)`` where ``samples`` is a 1-D
        numpy array of signed integer PCM samples.
    """
    # Convert the av frame into a pydub.AudioSegment for the pipeline.
    # NOTE(review): PyAV's AudioFormat exposes the per-sample byte width
    # as ``.bytes``; ``format.sample_width`` and ``tobytes()`` look
    # suspect — confirm against the installed PyAV version.
    audio_segment = pydub.AudioSegment(
        data=audio_chunk.tobytes(),
        sample_width=audio_chunk.format.sample_width,
        frame_rate=audio_chunk.sample_rate,
        channels=len(audio_chunk.layout.channels)
    )
    # Route through the language-specific STT model; NLLB translates in
    # the corresponding direction.
    if language == "en":
        processed_audio = stream(parakeet, model_nllb, tokenizer_nllb, "english", "spanish", audio_segment)
    elif language == "es":
        processed_audio = stream(stt_model, model_nllb, tokenizer_nllb, "spanish", "english", audio_segment)
    else:
        # Bug fix: the original returned the bare frame here while the
        # other branches return (rate, ndarray); the caller unpacks two
        # values, so the fallback path raised. Pass the input through
        # unmodified but with the same tuple shape.
        return audio_chunk.sample_rate, np.array(audio_segment.get_array_of_samples())
    # Convert the processed segment back to a numpy array for the caller.
    processed_audio_np = np.array(processed_audio.get_array_of_samples())
    return processed_audio.frame_rate, processed_audio_np
def audio_callback(frame: av.AudioFrame, language):
    """Adapt an incoming WebRTC audio frame and feed it to process_audio.

    Args:
        frame: raw ``av.AudioFrame`` received from the browser.
        language: ``"en"`` or ``"es"``, as selected in the UI.

    Returns:
        The ``(frame_rate, samples)`` tuple produced by ``process_audio``.
    """
    audio_data = frame.to_ndarray()
    # Rebuild a mono s16 frame from the raw samples.
    # NOTE(review): this assumes the incoming frame is already s16-shaped
    # data compatible with a mono layout — confirm with the browser track.
    audio_chunk = av.AudioFrame.from_ndarray(audio_data, format="s16", layout="mono")
    # Bug fix: from_ndarray does not carry over the sample rate, so the
    # rebuilt frame reported a bogus rate downstream. Copy it explicitly
    # from the source frame.
    audio_chunk.sample_rate = frame.sample_rate
    return process_audio(audio_chunk, language)
# --- Streamlit UI ---
st.title("Real-Time Audio Processing")
# Translation direction: "en" = English -> Spanish, "es" = Spanish -> English.
language = st.radio("Select Language", ["en", "es"], index=0)
# Bidirectional WebRTC session: receives mic audio from the browser and
# can send processed audio back (SENDRECV). Video is disabled.
webrtc_ctx = webrtc_streamer(
    key="audio",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=RTC_CONFIGURATION,
    media_stream_constraints={"audio": True, "video": False},
    audio_receiver_size=256,          # max frames buffered by the receiver queue
    async_processing=True,
)
# NOTE(review): streamlit_webrtc's AudioReceiver does not document an
# ``.on("data", ...)`` hook — frames are normally pulled via
# ``get_frames()`` (as done below). Confirm this callback is ever
# invoked; it may be dead code duplicating the loop that follows.
if webrtc_ctx.audio_receiver:
    webrtc_ctx.audio_receiver.on("data", lambda frame: audio_callback(frame, language))
# Accumulate processed audio across Streamlit reruns: session_state
# survives reruns, so the WAV buffer grows as frames arrive.
if "audio_buffer" not in st.session_state:
    st.session_state["audio_buffer"] = BytesIO()
if webrtc_ctx.audio_receiver:
    # Drain the frames queued by the receiver since the last rerun.
    audio_frames = webrtc_ctx.audio_receiver.get_frames()
    for frame in audio_frames:
        # NOTE(review): this unpacking assumes audio_callback returns a
        # (rate, samples) tuple; the unknown-language branch of
        # process_audio returns a bare frame, which would raise here.
        processed_audio_rate, processed_audio_np = audio_callback(frame, language)
        # Re-wrap the numpy samples as a mono AudioSegment so pydub can
        # serialize them to WAV.
        audio_segment = pydub.AudioSegment(
            data=processed_audio_np.tobytes(),
            sample_width=processed_audio_np.dtype.itemsize,
            frame_rate=processed_audio_rate,
            channels=1
        )
        # Append the WAV-encoded chunk to the session-wide buffer.
        st.session_state["audio_buffer"].write(audio_segment.export(format="wav").read())
# Play back everything accumulated so far.
st.audio(st.session_state["audio_buffer"].getvalue(), format="audio/wav")