carlosgomes98
commited on
Commit
•
67661a3
1
Parent(s):
9e63f80
remove timestep restriction
Browse files- Prithvi_run_inference.py +8 -4
Prithvi_run_inference.py
CHANGED
@@ -252,7 +252,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
252 |
params = yaml.safe_load(f)
|
253 |
|
254 |
# data related
|
255 |
-
num_frames =
|
256 |
img_size = params['img_size']
|
257 |
bands = params['bands']
|
258 |
mean = params['data_mean']
|
@@ -272,8 +272,9 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
272 |
|
273 |
mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
|
274 |
|
275 |
-
|
276 |
-
|
|
|
277 |
|
278 |
if torch.cuda.is_available():
|
279 |
device = torch.device('cuda')
|
@@ -310,7 +311,10 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
310 |
model.to(device)
|
311 |
|
312 |
state_dict = torch.load(checkpoint, map_location=device)
|
313 |
-
|
|
|
|
|
|
|
314 |
print(f"Loaded checkpoint from {checkpoint}")
|
315 |
|
316 |
# Running model --------------------------------------------------------------------------------
|
|
|
252 |
params = yaml.safe_load(f)
|
253 |
|
254 |
# data related
|
255 |
+
num_frames = len(data_files)
|
256 |
img_size = params['img_size']
|
257 |
bands = params['bands']
|
258 |
mean = params['data_mean']
|
|
|
272 |
|
273 |
mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
|
274 |
|
275 |
+
print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
|
276 |
+
if len(data_files) != 3:
|
277 |
+
print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")
|
278 |
|
279 |
if torch.cuda.is_available():
|
280 |
device = torch.device('cuda')
|
|
|
311 |
model.to(device)
|
312 |
|
313 |
state_dict = torch.load(checkpoint, map_location=device)
|
314 |
+
# discard fixed pos_embedding weight
|
315 |
+
del state_dict['pos_embed']
|
316 |
+
del state_dict['decoder_pos_embed']
|
317 |
+
model.load_state_dict(state_dict, strict=False)
|
318 |
print(f"Loaded checkpoint from {checkpoint}")
|
319 |
|
320 |
# Running model --------------------------------------------------------------------------------
|