jadechoghari
commited on
Commit
•
0370002
1
Parent(s):
f1efba5
Update modeling.py
Browse files- modeling.py +10 -10
modeling.py
CHANGED
@@ -110,16 +110,16 @@ class MARModel(PreTrainedModel):
|
|
110 |
# call the sample_tokens method from the MAR class
|
111 |
return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
|
112 |
|
113 |
-
@classmethod
|
114 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
def save_pretrained(self, save_directory):
|
|
|
110 |
# call the sample_tokens method from the MAR class
|
111 |
return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
|
112 |
|
113 |
+
# @classmethod
|
114 |
+
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
115 |
+
# config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
116 |
+
# model = cls(config)
|
117 |
+
# safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
|
118 |
+
# if not os.path.exists(safetensors_path):
|
119 |
+
# raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
|
120 |
+
# state_dict = torch.load(safetensors_path, map_location='cpu')
|
121 |
+
# model.model.load_state_dict(state_dict)
|
122 |
+
# return model
|
123 |
|
124 |
|
125 |
def save_pretrained(self, save_directory):
|