freddyaboulton HF staff commited on
Commit
2084afa
1 Parent(s): c4d6bf6
Files changed (2) hide show
  1. app.py +48 -44
  2. requirements.txt +1 -1
app.py CHANGED
@@ -51,7 +51,7 @@ OUT_CHANNELS = 1
51
  def run_vad(ori_audio, sr):
52
  _st = time.time()
53
  try:
54
- audio = np.frombuffer(ori_audio, dtype=np.int16)
55
  audio = audio.astype(np.float32) / 32768.0
56
  sampling_rate = 16000
57
  if sr != sampling_rate:
@@ -87,42 +87,32 @@ def warm_up():
87
  warm_up()
88
 
89
 
90
- def determine_pause(stream: bytes, start_talking: bool) -> tuple[bool, bool]:
91
  """Take in the stream, determine if a pause happened"""
92
 
93
- temp_audio = stream
 
 
 
94
 
95
- if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
96
- dur_vad, _, time_vad = run_vad(temp_audio, IN_RATE)
97
 
98
- print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
99
 
100
- if dur_vad > 0.2 and not start_talking:
101
- start_talking = True
102
- pause = False
103
- return pause, start_talking
104
- if dur_vad < 0.1 and start_talking:
105
- print("pause detected")
106
- return True, start_talking
107
- return False, start_talking
108
- return False, start_talking
109
 
110
-
111
- def speaking(total_frames: bytes):
112
  audio_buffer = io.BytesIO()
113
- wf = wave.open(audio_buffer, "wb")
114
- wf.setnchannels(IN_CHANNELS)
115
- wf.setsampwidth(IN_SAMPLE_WIDTH)
116
- wf.setframerate(IN_RATE)
117
-
118
- dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
119
- print(f"Speaking... recorded audio duration: {dur:.3f} s")
120
-
121
- wf.writeframes(total_frames)
122
 
123
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
124
- with open(tmpfile.name, "wb") as f:
125
- f.write(audio_buffer.getvalue())
126
 
127
  audio_bytes = audio_buffer.getvalue()
128
 
@@ -152,31 +142,38 @@ def speaking(total_frames: bytes):
152
  except Exception as e:
153
  raise gr.Error(f"Error during audio streaming: {e}")
154
 
155
- wf.close()
156
 
157
 
158
  @dataclass
159
  class AppState:
160
- start_talking: bool = False
161
- stream: bytes = b""
162
  pause_detected: bool = False
163
 
164
 
165
- def process_audio(audio: str, state: AppState):
166
- state.stream += Path(audio).read_bytes()
 
 
 
 
167
 
168
- pause_detected, start_talking = determine_pause(state.stream, state.pause_detected)
169
  state.pause_detected = pause_detected
170
- state.start_talking = start_talking
171
 
172
- if not state.pause_detected:
173
- yield None, state
 
174
 
175
- for out_bytes in speaking(state.stream):
176
- yield out_bytes, state
177
 
178
- state = AppState()
179
- yield None, state
 
 
 
 
 
 
180
 
181
 
182
  with gr.Blocks() as demo:
@@ -189,13 +186,20 @@ with gr.Blocks() as demo:
189
  output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
190
  state = gr.State(value=AppState())
191
 
192
- input_audio.stop_recording(
193
  process_audio,
194
  [input_audio, state],
195
- [output_audio, state],
196
  stream_every=0.5,
197
  time_limit=30,
198
  )
 
 
 
 
 
 
 
199
 
200
 
201
  demo.launch()
 
51
  def run_vad(ori_audio, sr):
52
  _st = time.time()
53
  try:
54
+ audio = ori_audio
55
  audio = audio.astype(np.float32) / 32768.0
56
  sampling_rate = 16000
57
  if sr != sampling_rate:
 
87
  warm_up()
88
 
89
 
90
+ def determine_pause(audio: np.ndarray, sampling_rate: int) -> bool:
91
  """Take in the stream, determine if a pause happened"""
92
 
93
+ temp_audio = audio
94
+
95
+ dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
96
+ duration = len(audio) / sampling_rate
97
 
98
+ print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
 
99
 
100
+ return (duration - dur_vad) > 0.5
101
 
 
 
 
 
 
 
 
 
 
102
 
103
+ def speaking(audio: np.ndarray, sampling_rate: int):
 
104
  audio_buffer = io.BytesIO()
105
+
106
+ audio = AudioSegment(
107
+ data.tobytes(),
108
+ frame_rate=sampling_rate,
109
+ sample_width=data.dtype.itemsize,
110
+ channels=(1 if len(data.shape) == 1 else data.shape[1]),
111
+ )
112
+ file = audio.export(audio_buffer, format="wav")
 
113
 
114
+ with open("input_audio.wav", "wb") as f:
115
+ f.write(audio_buffer.getvalue())
 
116
 
117
  audio_bytes = audio_buffer.getvalue()
118
 
 
142
  except Exception as e:
143
  raise gr.Error(f"Error during audio streaming: {e}")
144
 
 
145
 
146
 
147
  @dataclass
148
  class AppState:
149
+ stream: np.ndarray | None = None
150
+ sampling_rate: int = 0
151
  pause_detected: bool = False
152
 
153
 
154
+ def process_audio(audio: tuple, state: AppState):
155
+ if state.stream is None:
156
+ state.stream = audio[1]
157
+ state.sampling_rate = audio[0]
158
+ else:
159
+ state.stream = np.concatenate((state.stream, audio[1]))
160
 
161
+ pause_detected = determine_pause(state.stream, state.sampling_rate)
162
  state.pause_detected = pause_detected
 
163
 
164
+ if state.pause_detected:
165
+ return gr.Audio(recording=False), state
166
+ return None, state
167
 
 
 
168
 
169
+ def response(state: AppState):
170
+ if not state.pause_detected:
171
+ return None, None, AppState()
172
+
173
+ for mp3_bytes in speaking(state.stream, state.sampling_rate):
174
+ yield None, mp3_bytes, state
175
+
176
+ yield gr.Audio(recording=True), None, AppState()
177
 
178
 
179
  with gr.Blocks() as demo:
 
186
  output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
187
  state = gr.State(value=AppState())
188
 
189
+ stream = input_audio.stream(
190
  process_audio,
191
  [input_audio, state],
192
+ [input_audio, state],
193
  stream_every=0.5,
194
  time_limit=30,
195
  )
196
+ respond = inp.stop_recording(
197
+ response,
198
+ [state],
199
+ [input_audio, output_audio, state]
200
+ )
201
+ cancel = gr.Button("Stop Conversation", variant="stop")
202
+ cancel.click(lambda: AppState(), None, [state], cancels=[respond])
203
 
204
 
205
  demo.launch()
requirements.txt CHANGED
@@ -11,7 +11,7 @@ streamlit==1.37.1
11
  pydub==0.25.1
12
  onnxruntime==1.19.0
13
  # numpy==1.26.3
14
- https://gradio-builds.s3.amazonaws.com/5.0-dev/e2157efe20cdec2454b0b5d312fad00b2b5bfe1c/gradio-5.0.0b1-py3-none-any.whl
15
  fastapi==0.112.4
16
  librosa==0.10.2.post1
17
  flask==3.0.3
 
11
  pydub==0.25.1
12
  onnxruntime==1.19.0
13
  # numpy==1.26.3
14
+ https://gradio-builds.s3.amazonaws.com/e3011b3b19ee8f7b7fc2dbba848d56a0b30b6cdb/gradio-5.0.0b1-py3-none-any.whl
15
  fastapi==0.112.4
16
  librosa==0.10.2.post1
17
  flask==3.0.3