Wuvin commited on
Commit
8d1277c
1 Parent(s): 8cb0437
Files changed (1) hide show
  1. gradio_app/custom_models/utils.py +1 -1
gradio_app/custom_models/utils.py CHANGED
@@ -59,7 +59,7 @@ def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight
59
  shared_modules = trainer.init_shared_modules(shared_modules)
60
 
61
  if load_from_checkpoint is not None:
62
- state_dict = torch.load(load_from_checkpoint)
63
  configurable_unet.unet.load_state_dict(state_dict, strict=False)
64
  # Move unet, vae and text_encoder to device and cast to weight_dtype
65
  configurable_unet.unet.to(device, dtype=weight_dtype)
 
59
  shared_modules = trainer.init_shared_modules(shared_modules)
60
 
61
  if load_from_checkpoint is not None:
62
+ state_dict = torch.load(load_from_checkpoint, map_location="cpu")
63
  configurable_unet.unet.load_state_dict(state_dict, strict=False)
64
  # Move unet, vae and text_encoder to device and cast to weight_dtype
65
  configurable_unet.unet.to(device, dtype=weight_dtype)