carlosgomes98 commited on
Commit
67661a3
1 Parent(s): 9e63f80

remove timestep restriction

Browse files
Files changed (1) hide show
  1. Prithvi_run_inference.py +8 -4
Prithvi_run_inference.py CHANGED
@@ -252,7 +252,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
252
  params = yaml.safe_load(f)
253
 
254
  # data related
255
- num_frames = params['num_frames']
256
  img_size = params['img_size']
257
  bands = params['bands']
258
  mean = params['data_mean']
@@ -272,8 +272,9 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
272
 
273
  mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
274
 
275
- # We must have *num_frames* files to build one example!
276
- assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
 
277
 
278
  if torch.cuda.is_available():
279
  device = torch.device('cuda')
@@ -310,7 +311,10 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
310
  model.to(device)
311
 
312
  state_dict = torch.load(checkpoint, map_location=device)
313
- model.load_state_dict(state_dict)
 
 
 
314
  print(f"Loaded checkpoint from {checkpoint}")
315
 
316
  # Running model --------------------------------------------------------------------------------
 
252
  params = yaml.safe_load(f)
253
 
254
  # data related
255
+ num_frames = len(data_files)
256
  img_size = params['img_size']
257
  bands = params['bands']
258
  mean = params['data_mean']
 
272
 
273
  mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
274
 
275
+ print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
276
+ if len(data_files) != 3:
277
+ print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")
278
 
279
  if torch.cuda.is_available():
280
  device = torch.device('cuda')
 
311
  model.to(device)
312
 
313
  state_dict = torch.load(checkpoint, map_location=device)
314
+ # discard fixed pos_embedding weight
315
+ del state_dict['pos_embed']
316
+ del state_dict['decoder_pos_embed']
317
+ model.load_state_dict(state_dict, strict=False)
318
  print(f"Loaded checkpoint from {checkpoint}")
319
 
320
  # Running model --------------------------------------------------------------------------------