gpt-omni commited on
Commit
9f64e91
1 Parent(s): 5f2e2de

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -1
inference.py CHANGED
@@ -355,7 +355,7 @@ def load_model(ckpt_dir, device):
355
  config.post_adapter = False
356
 
357
  with fabric.init_module(empty_init=False):
358
- model = GPT(config)
359
 
360
  # model = fabric.setup(model)
361
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
 
355
  config.post_adapter = False
356
 
357
  with fabric.init_module(empty_init=False):
358
+ model = GPT(config, device=device)
359
 
360
  # model = fabric.setup(model)
361
  state_dict = lazy_load(ckpt_dir + "/lit_model.pth")