gpt-omni commited on
Commit
369b919
1 Parent(s): a4183a1
Files changed (1) hide show
  1. inference.py +3 -2
inference.py CHANGED
@@ -359,10 +359,11 @@ def load_model(ckpt_dir, device):
359
  with fabric.init_module(empty_init=False):
360
  model = GPT(config)
361
 
362
- model = fabric.setup(model)
363
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
364
  model.load_state_dict(state_dict, strict=True)
365
- model.to(device).eval()
 
366
 
367
  return fabric, model, text_tokenizer, snacmodel, whispermodel
368
 
 
359
  with fabric.init_module(empty_init=False):
360
  model = GPT(config)
361
 
362
+ # model = fabric.setup(model)
363
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
364
  model.load_state_dict(state_dict, strict=True)
365
+ model = model.to(device)
366
+ model.eval()
367
 
368
  return fabric, model, text_tokenizer, snacmodel, whispermodel
369