jadechoghari
commited on
Commit
•
f1efba5
1
Parent(s):
b4c4545
Update modeling.py
Browse files- modeling.py +5 -1
modeling.py
CHANGED
@@ -114,10 +114,14 @@ class MARModel(PreTrainedModel):
|
|
114 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
115 |
config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
116 |
model = cls(config)
|
117 |
-
|
|
|
|
|
|
|
118 |
model.model.load_state_dict(state_dict)
|
119 |
return model
|
120 |
|
|
|
121 |
def save_pretrained(self, save_directory):
|
122 |
# we will save to safetensors
|
123 |
os.makedirs(save_directory, exist_ok=True)
|
|
|
114 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
115 |
config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
116 |
model = cls(config)
|
117 |
+
safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
|
118 |
+
if not os.path.exists(safetensors_path):
|
119 |
+
raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
|
120 |
+
state_dict = torch.load(safetensors_path, map_location='cpu')
|
121 |
model.model.load_state_dict(state_dict)
|
122 |
return model
|
123 |
|
124 |
+
|
125 |
def save_pretrained(self, save_directory):
|
126 |
# we will save to safetensors
|
127 |
os.makedirs(save_directory, exist_ok=True)
|