gpt-omni committed on
Commit
58227c7
1 Parent(s): 369b919
Files changed (1) hide show
  1. inference.py +4 -3
inference.py CHANGED
@@ -138,6 +138,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
138
  return torch.stack([audio_feature, audio_feature]), stacked_inputids
139
 
140
 
 
141
  def load_audio(path):
142
  audio = whisper.load_audio(path)
143
  duration_ms = (len(audio) / 16000) * 1000
@@ -357,7 +358,7 @@ def load_model(ckpt_dir, device):
357
  config.post_adapter = False
358
 
359
  with fabric.init_module(empty_init=False):
360
- model = GPT(config)
361
 
362
  # model = fabric.setup(model)
363
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
@@ -401,8 +402,8 @@ class OmniInference:
401
  assert os.path.exists(audio_path), f"audio file {audio_path} not found"
402
  model = self.model
403
 
404
- with self.fabric.init_tensor():
405
- model.set_kv_cache(batch_size=2)
406
 
407
  mel, leng = load_audio(audio_path)
408
  audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
 
138
  return torch.stack([audio_feature, audio_feature]), stacked_inputids
139
 
140
 
141
+ @spaces.GPU
142
  def load_audio(path):
143
  audio = whisper.load_audio(path)
144
  duration_ms = (len(audio) / 16000) * 1000
 
358
  config.post_adapter = False
359
 
360
  with fabric.init_module(empty_init=False):
361
+ model = GPT(config, device=device)
362
 
363
  # model = fabric.setup(model)
364
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
 
402
  assert os.path.exists(audio_path), f"audio file {audio_path} not found"
403
  model = self.model
404
 
405
+ # with self.fabric.init_tensor():
406
+ model.set_kv_cache(batch_size=2)
407
 
408
  mel, leng = load_audio(audio_path)
409
  audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)