jadechoghari commited on
Commit
0370002
1 Parent(s): f1efba5

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +10 -10
modeling.py CHANGED
@@ -110,16 +110,16 @@ class MARModel(PreTrainedModel):
110
  # call the sample_tokens method from the MAR class
111
  return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
112
 
113
- @classmethod
114
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
115
- config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
116
- model = cls(config)
117
- safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
118
- if not os.path.exists(safetensors_path):
119
- raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
120
- state_dict = torch.load(safetensors_path, map_location='cpu')
121
- model.model.load_state_dict(state_dict)
122
- return model
123
 
124
 
125
  def save_pretrained(self, save_directory):
 
110
  # call the sample_tokens method from the MAR class
111
  return self.model.sample_tokens(bsz, num_iter, cfg, cfg_schedule, labels, temperature, progress)
112
 
113
+ # @classmethod
114
+ # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
115
+ # config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
116
+ # model = cls(config)
117
+ # safetensors_path = os.path.join(pretrained_model_name_or_path, "checkpoint-last.safetensors")
118
+ # if not os.path.exists(safetensors_path):
119
+ # raise FileNotFoundError(f"safetensors file not found at {safetensors_path}")
120
+ # state_dict = torch.load(safetensors_path, map_location='cpu')
121
+ # model.model.load_state_dict(state_dict)
122
+ # return model
123
 
124
 
125
  def save_pretrained(self, save_directory):