freddyaboulton (HF staff) committed
Commit 4e9b286
1 Parent(s): 5e3f570
Files changed (2):
  1. app.py +125 -109
  2. requirements.txt +1 -4
app.py CHANGED
@@ -1,18 +1,26 @@
-import gradio as gr
-from huggingface_hub import snapshot_download
-from threading import Thread
-import time
 import base64
-import numpy as np
-import requests
+import io
+import tempfile
+import time
 import traceback
 from dataclasses import dataclass, field
-import io
+from queue import Queue
+from threading import Thread, Event
+
+import gradio as gr
+import librosa
+import numpy as np
+import requests
+from gradio_webrtc import StreamHandler, WebRTC
+from huggingface_hub import snapshot_download
 from pydub import AudioSegment
 import librosa
 from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
 import tempfile

+# from server import serve
+from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
+

 from server import serve

@@ -22,11 +30,15 @@ snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
 IP = "0.0.0.0"
 PORT = 60808

-serve(port=7860)
+thread = Thread(target=serve, daemon=True)
+thread.start()


 API_URL = "http://0.0.0.0:60808/chat"

+
+#API_URL = "https://freddyaboulton-omni-backend.hf.space/chat"
+
 # recording parameters
 IN_CHANNELS = 1
 IN_RATE = 24000
@@ -38,12 +50,7 @@ VAD_STRIDE = 0.5
 OUT_CHANNELS = 1
 OUT_RATE = 24000
 OUT_SAMPLE_WIDTH = 2
-OUT_CHUNK = 5760
-
-
 OUT_CHUNK = 20 * 4096
-OUT_RATE = 24000
-OUT_CHANNELS = 1


 def run_vad(ori_audio, sr):
@@ -77,94 +84,109 @@ def run_vad(ori_audio, sr):


 def warm_up():
-    frames = b"\x00\x00" * 1024 * 2  # 1024 frames of 2 bytes each
-    dur, frames, tcost = run_vad(frames, 16000)
+    frames = np.zeros((1, 1600))  # 1024 frames of 2 bytes each
+    _, frames, tcost = run_vad(frames, 16000)
     print(f"warm up done, time_cost: {tcost:.3f} s")


 warm_up()

-
 @dataclass
 class AppState:
     stream: np.ndarray | None = None
     sampling_rate: int = 0
     pause_detected: bool = False
-    started_talking: bool = False
+    started_talking: bool = False
+    responding: bool = False
     stopped: bool = False
-    conversation: list = field(default_factory=list)
+    buffer: np.ndarray | None = None
+


 def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
     """Take in the stream, determine if a pause happened"""
-
-    temp_audio = audio
-
-    dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
     duration = len(audio) / sampling_rate
+
+    dur_vad, _, _ = run_vad(audio, sampling_rate)
+
+    if duration >= 0.60:
+        if dur_vad > 0.2 and not state.started_talking:
+            print("started talking")
+            state.started_talking = True
+        if state.started_talking:
+            if state.stream is None:
+                state.stream = audio
+            else:
+                state.stream = np.concatenate((state.stream, audio))
+            state.buffer = None
+        if dur_vad < 0.1 and state.started_talking:
+            segment = AudioSegment(
+                state.stream.tobytes(),
+                frame_rate=sampling_rate,
+                sample_width=audio.dtype.itemsize,
+                channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
+            )

-    if dur_vad > 0.5 and not state.started_talking:
-        print("started talking")
-        state.started_talking = True
-        return False
-
-    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
-
-    return (duration - dur_vad) > 1
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                segment.export(f.name, format="wav")
+            print("input file written", f.name)
+            return True
+    return False


 def speaking(audio_bytes: str):

     base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
     files = {"audio": base64_encoded}
+    byte_buffer = b""
     with requests.post(API_URL, json=files, stream=True) as response:
         try:
             for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                 if chunk:
                     # Create an audio segment from the numpy array
+                    byte_buffer += chunk
                     audio_segment = AudioSegment(
-                        chunk,
+                        chunk + b"\x00" if len(chunk) % 2 != 0 else chunk,
                         frame_rate=OUT_RATE,
                         sample_width=OUT_SAMPLE_WIDTH,
                         channels=OUT_CHANNELS,
                     )
-
-                    # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
-                    mp3_io = io.BytesIO()
-                    audio_segment.export(mp3_io, format="mp3", bitrate="320k")
-
-                    # Get the MP3 bytes
-                    mp3_bytes = mp3_io.getvalue()
-                    mp3_io.close()
-                    yield mp3_bytes
-
+                    # Export the audio segment to a numpy array
+                    audio_np = np.array(audio_segment.get_array_of_samples())
+                    yield audio_np.reshape(1, -1)
+            all_output_audio = AudioSegment(
+                byte_buffer,
+                frame_rate=OUT_RATE,
+                sample_width=OUT_SAMPLE_WIDTH,
+                channels=1,
+            )
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                all_output_audio.export(f.name, format="wav")
+            print("output file written", f.name)
         except Exception as e:
             raise gr.Error(f"Error during audio streaming: {e}")



-
-def process_audio(audio: tuple, state: AppState):
-    if state.stream is None:
-        state.stream = audio[1]
-        state.sampling_rate = audio[0]
+def process_audio(audio: tuple, state: AppState) -> None:
+    frame_rate, array = audio
+    array = np.squeeze(array)
+    if not state.sampling_rate:
+        state.sampling_rate = frame_rate
+    if state.buffer is None:
+        state.buffer = array
     else:
-        state.stream = np.concatenate((state.stream, audio[1]))
+        state.buffer = np.concatenate((state.buffer, array))

-    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
+    pause_detected = determine_pause(state.buffer, state.sampling_rate, state)
     state.pause_detected = pause_detected

-    if state.pause_detected and state.started_talking:
-        return gr.Audio(recording=False), state
-    return None, state
-

 def response(state: AppState):
     if not state.pause_detected and not state.started_talking:
-        return None, AppState()
+        return None

     audio_buffer = io.BytesIO()
-
     segment = AudioSegment(
         state.stream.tobytes(),
         frame_rate=state.sampling_rate,
@@ -172,68 +194,62 @@ def response(state: AppState):
         channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
     )
     segment.export(audio_buffer, format="wav")
-
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
-        f.write(audio_buffer.getvalue())

-    state.conversation.append({"role": "user",
-                               "content": {"path": f.name,
-                                           "mime_type": "audio/wav"}})
+    for numpy_array in speaking(audio_buffer.getvalue()):
+        yield (OUT_RATE, numpy_array, "mono")
+
+
+class OmniHandler(StreamHandler):
+    def __init__(self) -> None:
+        super().__init__(expected_layout="mono", output_sample_rate=OUT_RATE, output_frame_size=480)
+        self.chunk_queue = Queue()
+        self.state = AppState()
+        self.generator = None
+        self.duration = 0
+
+    def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        if self.state.responding:
+            return
+        process_audio(frame, self.state)
+        if self.state.pause_detected:
+            self.chunk_queue.put(True)

-    output_buffer = b""
-
-    for mp3_bytes in speaking(audio_buffer.getvalue()):
-        output_buffer += mp3_bytes
-        yield mp3_bytes, state
-
-    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
-        f.write(output_buffer)
-
-    state.conversation.append({"role": "assistant",
-                               "content": {"path": f.name,
-                                           "mime_type": "audio/mp3"}})
-    yield None, AppState(conversation=state.conversation)
-
-
-
+    def reset(self):
+        self.generator = None
+        self.state = AppState()
+        self.duration = 0
+
+    def emit(self):
+        if not self.generator:
+            self.chunk_queue.get()
+            self.state.responding = True
+            self.generator = response(self.state)
+        try:
+            return next(self.generator)
+        except StopIteration:
+            self.reset()
+

-def start_recording_user(state: AppState):
-    if not state.stopped:
-        return gr.Audio(recording=True)

 with gr.Blocks() as demo:
-    with gr.Row():
-        with gr.Column():
-            input_audio = gr.Audio(
-                label="Input Audio", sources="microphone", type="numpy"
-            )
-        with gr.Column():
-            chatbot = gr.Chatbot(label="Conversation", type="messages")
-            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
-    state = gr.State(value=AppState())
-
-    stream = input_audio.stream(
-        process_audio,
-        [input_audio, state],
-        [input_audio, state],
-        stream_every=0.50,
-        time_limit=30,
-    )
-    respond = input_audio.stop_recording(
-        response,
-        [state],
-        [output_audio, state]
+    gr.HTML(
+        """
+    <h1 style='text-align: center'>
+    Omni Chat (Powered by WebRTC ⚡️)
+    </h1>
+    """
     )
-    respond.then(lambda s: s.conversation, [state], [chatbot])
+    with gr.Column():
+        with gr.Group():
+            audio = WebRTC(
+                label="Stream",
+                rtc_configuration=None,
+                mode="send-receive",
+                modality="audio",
+            )
+        audio.stream(fn=OmniHandler(), inputs=[audio], outputs=[audio], time_limit=300)
+

-    restart = output_audio.stop(
-        start_recording_user,
-        [state],
-        [input_audio]
-    )
-    cancel = gr.Button("Stop Conversation", variant="stop")
-    cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
-                 [state, input_audio], cancels=[respond, restart])


 demo.launch()
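Note on the new `WebRTC` component: the diff creates it with `rtc_configuration=None`, which is fine for local testing, but deployments behind restrictive NATs (for example on Spaces) usually need STUN/TURN servers passed through `rtc_configuration`. A minimal sketch of fetching such a configuration from Twilio's token API; the helper name and environment-variable names are illustrative assumptions, not part of this commit:

```python
import os

from twilio.rest import Client  # assumes the optional `twilio` package is installed


def get_rtc_configuration() -> dict:
    # Fetch ephemeral STUN/TURN servers from Twilio; credentials come from env vars.
    client = Client(os.environ["TWILIO_ACCOUNT_SID"], os.environ["TWILIO_AUTH_TOKEN"])
    token = client.tokens.create()
    return {"iceServers": token.ice_servers}


# Hypothetical usage:
# audio = WebRTC(label="Stream", rtc_configuration=get_rtc_configuration(),
#                mode="send-receive", modality="audio")
```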
requirements.txt CHANGED
@@ -6,13 +6,10 @@ snac==1.2.0
 soundfile==0.12.1
 openai-whisper
 tokenizers==0.19.1
-streamlit==1.37.1
-# PyAudio==0.2.14
 pydub==0.25.1
 onnxruntime==1.19.0
-# numpy==1.26.3
-https://gradio-builds.s3.amazonaws.com/cffe9a7ab7f71e76d7214dc57c6278ffaf5bcdf9/gradio-5.0.0b1-py3-none-any.whl
 fastapi==0.112.4
 librosa==0.10.2.post1
 flask==3.0.3
 fire
+https://gradio-builds.s3.us-east-1.amazonaws.com/webrtc/08/gradio_webrtc-0.0.5-py3-none-any.whl