gpt-omni committed
Commit 7ba9b1d
1 Parent(s): e1adc1c
Files changed (1)
  1. app.py +6 -6
app.py CHANGED
@@ -128,7 +128,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
     return torch.stack([audio_feature, audio_feature]), stacked_inputids
 
-
+
 @spaces.GPU
 def next_token_batch(
     model: GPT,
@@ -156,7 +156,7 @@ def next_token_batch(
     next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
     return next_audio_tokens, next_t
 
-
+
 def load_audio(path):
     audio = whisper.load_audio(path)
     duration_ms = (len(audio) / 16000) * 1000
@@ -164,7 +164,7 @@ def load_audio(path):
     mel = whisper.log_mel_spectrogram(audio)
     return mel, int(duration_ms / 20) + 1
 
-
+
 @spaces.GPU
 def generate_audio_data(snac_tokens, snacmodel, device=None):
     audio = reconstruct_tensors(snac_tokens, device)
@@ -190,7 +190,7 @@ def run_AT_batch_stream(
 
     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
 
-    model.set_kv_cache(batch_size=2)
+    model.set_kv_cache(batch_size=2, device=device)
 
     mel, leng = load_audio(audio_path)
     audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
@@ -295,7 +295,7 @@ def run_AT_batch_stream(
     model.clear_kv_cache()
     return list_output
 
-
+
 for chunk in run_AT_batch_stream('./data/samples/output1.wav'):
     pass
 
@@ -326,4 +326,4 @@ demo = gr.Interface(
     # live=True,
 )
 demo.queue()
-demo.launch()
+demo.launch()
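The substantive change here is the extra device=device argument to model.set_kv_cache; the remaining paired -/+ lines look like whitespace-only edits. On a ZeroGPU Space the model is only moved to CUDA inside the @spaces.GPU-decorated call, so KV-cache buffers allocated without an explicit device default to the CPU and trip a device-mismatch error on the first attention step. The batch_size=2 lines up with get_input_ids_whisper_ATBatch, which stacks the same audio feature twice, presumably so the two decoding streams run as one batch. A minimal sketch of the failure mode, with a hypothetical KVCache standing in for whatever set_kv_cache allocates internally:

import torch

# Hypothetical stand-in for the buffers set_kv_cache allocates.
class KVCache:
    def __init__(self, batch_size, n_head, max_seq, head_dim, device=None):
        shape = (batch_size, n_head, max_seq, head_dim)
        # With device=None these land on the CPU even if the model
        # has already been moved to CUDA.
        self.k = torch.zeros(shape, device=device)
        self.v = torch.zeros(shape, device=device)

cache = KVCache(batch_size=2, n_head=16, max_seq=2048, head_dim=64)
print(cache.k.device)  # cpu -> mismatch once attention runs on CUDA

device = "cuda" if torch.cuda.is_available() else "cpu"
cache = KVCache(batch_size=2, n_head=16, max_seq=2048, head_dim=64, device=device)
print(cache.k.device)  # now matches the model's device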
 
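For context on the next_token_batch hunk: sample(logit_t, **kwargs) draws the next token id from the logits. The sampler itself is outside this diff; a common temperature/top-k implementation looks roughly like the sketch below (an assumption, not necessarily what this Space imports):

import torch

def sample(logits, temperature=1.0, top_k=None):
    # Greedy decode when temperature is zero.
    if temperature <= 0:
        return logits.argmax(dim=-1)
    logits = logits / temperature
    if top_k is not None:
        # Drop everything below the k-th largest logit per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

logits = torch.randn(2, 32000)                          # batch of 2 streams
print(sample(logits, temperature=0.8, top_k=50).shape)  # torch.Size([2])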
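The second value load_audio returns is a position count derived from the clip length: Whisper works at 16 kHz, so len(audio) / 16000 * 1000 is the duration in milliseconds, and int(duration_ms / 20) + 1 allots one position per 20 ms of audio, matching the encoder's output rate (10 ms mel hop times the 2x convolutional downsampling). The same arithmetic, self-contained:

import numpy as np

SAMPLE_RATE = 16_000  # Whisper resamples all input to 16 kHz
FRAME_MS = 20         # one encoder position per 20 ms of audio

def encoder_positions(num_samples):
    # Mirrors int(duration_ms / 20) + 1 in load_audio.
    duration_ms = num_samples / SAMPLE_RATE * 1000
    return int(duration_ms / FRAME_MS) + 1

audio = np.zeros(3 * SAMPLE_RATE, dtype=np.float32)  # a 3-second clip
print(encoder_positions(len(audio)))  # 151: 150 full frames plus one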
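At the bottom of the file, demo.queue() before demo.launch() is what lets a yielding function such as run_AT_batch_stream push partial audio to the client, since Gradio serves generator endpoints through its queue; the module-level loop over run_AT_batch_stream('./data/samples/output1.wav') just above appears to be a one-off warm-up pass at startup. A minimal version of the same streaming pattern, with a toy generator in place of the real pipeline:

import gradio as gr

def stream_words(prompt):
    # A generator endpoint: each yielded value is streamed to the client.
    text = ""
    for word in prompt.split():
        text += word + " "
        yield text.strip()

demo = gr.Interface(fn=stream_words, inputs="text", outputs="text")
demo.queue()   # generator endpoints are served through the queue
demo.launch()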