gpt-omni committed
Commit 399ac1f
1 Parent(s): 8696667

Upload 3 files

Files changed (1)
  1. inference.py +19 -21
inference.py CHANGED
@@ -2,7 +2,6 @@ import os
 import lightning as L
 import torch
 import time
-import spaces
 from snac import SNAC
 from litgpt import Tokenizer
 from litgpt.utils import (
@@ -147,8 +146,8 @@ def load_audio(path):
 
 def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                 snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=2)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
     tokenlist = generate_TA_BATCH(
         model,
         audio_feature,
@@ -191,8 +190,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
 
 
 def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AT(
         model,
         audio_feature,
@@ -214,8 +213,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AA(
         model,
         audio_feature,
@@ -256,8 +255,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_ASR(
         model,
         audio_feature,
@@ -280,8 +279,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def T1_A2(fabric, input_ids, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TA(
         model,
         None,
@@ -325,8 +324,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
 
 def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TT(
         model,
         None,
@@ -356,13 +355,12 @@ def load_model(ckpt_dir, device):
     config.post_adapter = False
 
     with fabric.init_module(empty_init=False):
-        model = GPT(config, device=device)
+        model = GPT(config)
 
-    # model = fabric.setup(model)
+    model = fabric.setup(model)
     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
     model.load_state_dict(state_dict, strict=True)
-    model = model.to(device)
-    model.eval()
+    model.to(device).eval()
 
     return fabric, model, text_tokenizer, snacmodel, whispermodel
 
@@ -385,8 +383,7 @@ class OmniInference:
         for _ in self.run_AT_batch_stream(sample):
             pass
 
-    # @torch.inference_mode()
-    @spaces.GPU
+    @torch.inference_mode()
     def run_AT_batch_stream(self,
                             audio_path,
                             stream_stride=4,
@@ -401,7 +398,8 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-        model.set_kv_cache(batch_size=2)
+        with self.fabric.init_tensor():
+            model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
@@ -419,7 +417,7 @@ class OmniInference:
         list_output = [[] for i in range(8)]
         tokens_A, token_T = next_token_batch(
             model,
-            audio_feature.to(torch.float32).to(device),
+            audio_feature.to(torch.float32).to(model.device),
             input_ids,
             [T - 3, T - 3],
             ["A1T2", "A1T2"],
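Note on the recurring change: every model.set_kv_cache(...) call is now wrapped in fabric.init_tensor(). In Lightning Fabric, init_tensor() is a context manager under which newly created tensors default to the Fabric-managed device and precision, so the freshly allocated KV-cache buffers land next to the model weights instead of on the CPU; the switch to model.device for audio_feature in the last hunk follows the same logic. Below is a minimal sketch of the pattern, assuming upstream lightning and litgpt (this repo ships its own modified litgpt) and using "pythia-70m" purely as an illustrative config name, not the model used here:

import lightning as L
from litgpt import GPT, Config

fabric = L.Fabric(devices=1)
fabric.launch()

config = Config.from_name("pythia-70m")    # illustrative config, not mini-omni's
with fabric.init_module(empty_init=False):
    model = GPT(config)                    # parameters are created on the Fabric device
model = fabric.setup(model)
model.eval()

# Tensors created under init_tensor() default to the Fabric device/precision,
# so the KV-cache buffers allocated by set_kv_cache() end up next to the weights.
with fabric.init_tensor():
    model.set_kv_cache(batch_size=1)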
 
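Note on the decorator change: the spaces import and the @spaces.GPU decorator are removed, dropping the Hugging Face Spaces (ZeroGPU) dependency, and run_AT_batch_stream is instead decorated with @torch.inference_mode(), which disables autograd tracking each time the generator is resumed. A small self-contained sketch of that behaviour (the Toy class below is hypothetical, not part of this repo):

import torch

class Toy:
    def __init__(self):
        self.weight = torch.randn(4, 4)

    @torch.inference_mode()
    def stream(self, steps=3):
        # inference_mode() wraps each resumption of the generator, so every
        # tensor produced here is an inference tensor with no autograd state.
        x = torch.randn(1, 4)
        for _ in range(steps):
            x = x @ self.weight
            yield x

for chunk in Toy().stream():
    assert chunk.is_inference()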