jadechoghari commited on
Commit
f1efba5
1 Parent(s): b4c4545

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +5 -1
modeling.py CHANGED
@@ -114,10 +114,14 @@ class MARModel(PreTrainedModel):
114
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
115
  config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
116
  model = cls(config)
117
- state_dict = torch.load('checkpoint-last.safetensors')
 
 
 
118
  model.model.load_state_dict(state_dict)
119
  return model
120
 
 
121
  def save_pretrained(self, save_directory):
122
  # we will save to safetensors
123
  os.makedirs(save_directory, exist_ok=True)
 
114
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
115
  config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
116
  model = cls(config)
117
+ safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
118
+ if not os.path.exists(safetensors_path):
119
+ raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
120
+ state_dict = torch.load(safetensors_path, map_location='cpu')
121
  model.model.load_state_dict(state_dict)
122
  return model
123
 
124
+
125
  def save_pretrained(self, save_directory):
126
  # we will save to safetensors
127
  os.makedirs(save_directory, exist_ok=True)