carlosgomes98 committed
Commit • c2435ef
Parent(s): ffd2463

fix inference

Browse files
- Prithvi_100M_config.yaml  +18 -18
- Prithvi_run_inference.py  +231 -113
- README.md  +3 -1
Prithvi_100M_config.yaml
CHANGED
@@ -12,25 +12,25 @@ model_args:
   tubelet_size: 1
 train_params:
   bands:
+    - B02
+    - B03
+    - B04
+    - B05
+    - B06
+    - B07
   data_mean:
+    - 775.2290211032589
+    - 1080.992780391705
+    - 1228.5855250417867
+    - 2497.2022620507532
+    - 2204.2139147975554
+    - 1610.8324823273745
   data_std:
+    - 1281.526139861424
+    - 1270.0297974547493
+    - 1399.4802505642526
+    - 1368.3446143747644
+    - 1291.6764008585435
+    - 1154.505683480695
   mask_ratio: 0.75
   random_cropping: true
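For context, these entries are read by the updated inference script through `train_params`. The following is a minimal sketch (not part of the commit) of how `bands`, `data_mean`, and `data_std` drive per-band normalization; the random chip, its shape, and the nodata constants are stand-ins mirroring the script's handling:

```python
import numpy as np
import yaml

NO_DATA = -9999         # assumed nodata fill values, mirroring the script's constants
NO_DATA_FLOAT = 0.0001

with open("Prithvi_100M_config.yaml") as f:
    params = yaml.safe_load(f)

train_params = params["train_params"]
mean = np.array(train_params["data_mean"])  # one value per band, same order as `bands`
std = np.array(train_params["data_std"])

# Stand-in for a loaded HLS chip, channels last: (H, W, C) with C == len(bands)
img = np.random.randint(0, 5000, size=(224, 224, len(train_params["bands"]))).astype("float32")

# Normalize each band; nodata pixels get a fixed fill value instead
img_norm = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
print(img_norm.shape, img_norm.dtype)
```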
Prithvi_run_inference.py
CHANGED
@@ -1,7 +1,7 @@
 import argparse
 import functools
 import os
-from typing import List
+from typing import List, Union

 import numpy as np
 import rasterio
@@ -17,7 +17,7 @@ PERCENTILES = (0.1, 99.9)


 def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
+    """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
     original range using *data_mean* and *data_std* and then lowest and highest percentiles are
     removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.

@@ -65,7 +65,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):


 def read_geotiff(file_path: str):
+    """Read all bands from *file_path* and return image + meta info.

     Args:
         file_path: path to image file.
@@ -83,7 +83,7 @@ def read_geotiff(file_path: str):


 def save_geotiff(image, output_path: str, meta: dict):
+    """Save multi-band image in Geotiff file.

     Args:
         image: np.ndarray with shape (bands, height, width)
@@ -99,15 +99,19 @@ def save_geotiff(image, output_path: str, meta: dict):


 def _convert_np_uint8(float_image: torch.Tensor):
     image = float_image.numpy() * 255.0
     image = image.astype(dtype=np.uint8)

     return image


-def load_example(file_paths: List[str], mean: List[float], std: List[float]):
+def load_example(
+    file_paths: List[str],
+    mean: List[float],
+    std: List[float],
+    indices: Union[list[int], None] = None,
+):
+    """Build an input example by loading images in *file_paths*.

     Args:
         file_paths: list of file paths .
@@ -126,21 +130,28 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
         img, meta = read_geotiff(file)

         # Rescaling (don't normalize on nodata)
+        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
+        if indices is not None:
+            img = img[..., indices]
         img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

         imgs.append(img)
         metas.append(meta)

+    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
+    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
     imgs = np.expand_dims(imgs, axis=0)  # add batch dim

     return imgs, metas


+def run_model(
+    model: torch.nn.Module,
+    input_data: torch.Tensor,
+    mask_ratio: float,
+    device: torch.device,
+):
+    """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).

     Args:
         model: MAE model to run.
@@ -158,12 +169,16 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
         _, pred, mask = model(x, mask_ratio)

     # Create mask and prediction images (un-patchify)
+    mask_img = (
+        model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
+    )
     pred_img = model.unpatchify(pred).detach().cpu()

     # Mix visible and predicted patches
     rec_img = input_data.clone()
+    rec_img[mask_img == 1] = pred_img[
+        mask_img == 1
+    ]  # binary mask: 0 is keep, 1 is remove

     # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
     mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
@@ -171,8 +186,10 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
     return rec_img, mask_img


+def save_rgb_imgs(
+    input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
+):
+    """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.

     Args:
         input_img: input torch.Tensor with shape (C, T, H, W).
@@ -186,30 +203,39 @@ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir,
     """

     for t in range(input_img.shape[1]):
+        rgb_orig, rgb_pred = process_channel_group(
+            orig_img=input_img[:, t, :, :],
+            new_img=rec_img[:, t, :, :],
+            channels=channels,
+            data_mean=mean,
+            data_std=std,
+        )

         rgb_mask = mask_img[channels, t, :, :] * rgb_orig

         # Saving images

+        save_geotiff(
+            image=_convert_np_uint8(rgb_orig),
+            output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
+            meta=meta_data[t],
+        )

+        save_geotiff(
+            image=_convert_np_uint8(rgb_pred),
+            output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
+            meta=meta_data[t],
+        )

+        save_geotiff(
+            image=_convert_np_uint8(rgb_mask),
+            output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
+            meta=meta_data[t],
+        )


 def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
+    """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.

     Args:
         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
@@ -224,7 +250,6 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
     std = torch.tensor(np.asarray(std)[:, None, None])

     for t in range(rec_img.shape[1]):
         # Back to original data range
         rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)

@@ -232,78 +257,98 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):

         # Saving images

+        save_geotiff(
+            image=rec_img_t,
+            output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
+            meta=meta_data[t],
+        )
+
+        save_geotiff(
+            image=mask_img_t,
+            output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
+            meta=meta_data[t],
+        )
+
+
+def main(
+    data_files: List[str],
+    yaml_file_path: str,
+    checkpoint: str,
+    output_dir: str,
+    rgb_outputs: bool,
+    img_size: int,
+    mask_ratio: float = None,
+    input_indices: list[int] = None,
+):
     os.makedirs(output_dir, exist_ok=True)

     # Get parameters --------

+    with open(yaml_file_path, "r") as f:
         params = yaml.safe_load(f)

     # data related
+    train_params = params["train_params"]
     num_frames = len(data_files)
+    bands = train_params["bands"]
+    mean = train_params["data_mean"]
+    std = train_params["data_std"]

     # model related
+    model_params = params["model_args"]
+    img_size = model_params["img_size"] if img_size is None else img_size
+    depth = model_params["depth"]
+    patch_size = model_params["patch_size"]
+    embed_dim = model_params["embed_dim"]
+    num_heads = model_params["num_heads"]
+    tubelet_size = model_params["tubelet_size"]
+    decoder_embed_dim = model_params["decoder_embed_dim"]
+    decoder_num_heads = model_params["decoder_num_heads"]
+    decoder_depth = model_params["decoder_depth"]
+
+    batch_size = 1
+
+    mask_ratio = train_params["mask_ratio"] if mask_ratio is None else mask_ratio
+
+    print(
+        f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
+    )
     if len(data_files) != 3:
+        print(
+            "The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
+        )

     if torch.cuda.is_available():
+        device = torch.device("cuda")
     else:
+        device = torch.device("cpu")

     print(f"Using {device} device.\n")

     # Loading data ---------------------------------------------------------------------------------

+    input_data, meta_data = load_example(
+        file_paths=data_files, indices=input_indices, mean=mean, std=std
+    )

     # Create model and load checkpoint -------------------------------------------------------------

     model = MaskedAutoencoderViT(
+        img_size=img_size,
+        patch_size=patch_size,
+        num_frames=num_frames,
+        tubelet_size=tubelet_size,
+        in_chans=len(bands),
+        embed_dim=embed_dim,
+        depth=depth,
+        num_heads=num_heads,
+        decoder_embed_dim=decoder_embed_dim,
+        decoder_depth=decoder_depth,
+        decoder_num_heads=decoder_num_heads,
+        mlp_ratio=4.0,
+        norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
+        norm_pix_loss=False,
+    )

     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print(f"\n--> Model has {total_params:,} parameters.\n")
@@ -312,27 +357,31 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir

     state_dict = torch.load(checkpoint, map_location=device)
     # discard fixed pos_embedding weight
+    del state_dict["pos_embed"]
+    del state_dict["decoder_pos_embed"]
     model.load_state_dict(state_dict, strict=False)
     print(f"Loaded checkpoint from {checkpoint}")

     # Running model --------------------------------------------------------------------------------

     model.eval()
+    channels = [bands.index(b) for b in ["B04", "B03", "B02"]]  # BGR -> RGB

     # Reflect pad if not divisible by img_size
     original_h, original_w = input_data.shape[-2:]
     pad_h = img_size - (original_h % img_size)
     pad_w = img_size - (original_w % img_size)
+    input_data = np.pad(
+        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
+    )

     # Build sliding window
+    batch = torch.tensor(input_data, device="cpu")
     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
     h1, w1 = windows.shape[3:5]
+    windows = rearrange(
+        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
+    )

     # Split into batches if number of windows > batch_size
     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
@@ -350,10 +399,28 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
     mask_imgs = torch.concat(mask_imgs, dim=0)

     # Build images from patches
+    rec_imgs = rearrange(
+        rec_imgs,
+        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
+        h=img_size,
+        w=img_size,
+        b=1,
+        c=len(bands),
+        t=num_frames,
+        h1=h1,
+        w1=w1,
+    )
+    mask_imgs = rearrange(
+        mask_imgs,
+        "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
+        h=img_size,
+        w=img_size,
+        b=1,
+        c=len(bands),
+        t=num_frames,
+        h1=h1,
+        w1=w1,
+    )

     # Cut padded images back to original size
     rec_imgs_full = rec_imgs[..., :original_h, :original_w]
@@ -363,37 +430,88 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
     # Build output images
     if rgb_outputs:
         for d in meta_data:
+            d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
+
+        save_rgb_imgs(
+            batch_full[0, ...],
+            rec_imgs_full[0, ...],
+            mask_imgs_full[0, ...],
+            channels,
+            mean,
+            std,
+            output_dir,
+            meta_data,
+        )
     else:
         for d in meta_data:
+            d.update(compress="lzw", nodata=0)

+        save_imgs(
+            rec_imgs_full[0, ...],
+            mask_imgs_full[0, ...],
+            mean,
+            std,
+            output_dir,
+            meta_data,
+        )

     print("Done!")


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+    parser.add_argument(
+        "--data_files",
+        required=True,
+        type=str,
+        nargs="+",
+        help="Path to the data files. Assumes multi-band files.",
+    )
+    parser.add_argument(
+        "--yaml_file_path",
+        type=str,
+        required=True,
+        help="Path to yaml file containing model training parameters.",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        required=True,
+        type=str,
+        help="Path to a checkpoint file to load from.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        required=True,
+        type=str,
+        help="Path to the directory where to save outputs.",
+    )
+    parser.add_argument(
+        "--mask_ratio",
+        default=None,
+        type=float,
+        help="Masking ratio (percentage of removed patches). "
+        "If None (default) use same value used for pretraining.",
+    )
+    parser.add_argument(
+        "--img_size",
+        default=224,
+        type=int,
+        help="Image size to be used with model. Defaults to 224",
+    )
+    parser.add_argument(
+        "--input_indices",
+        default=None,
+        type=int,
+        nargs="+",
+        help="0-based indices of channels to be selected from the input. By default takes all.",
+    )
+    parser.add_argument(
+        "--rgb_outputs",
+        action="store_true",
+        help="If present, output files will only contain RGB channels. "
+        "Otherwise, all bands will be saved.",
+    )
     args = parser.parse_args()

     main(**vars(args))
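The heart of this inference fix is the sliding-window path in `main()`: inputs of arbitrary size are reflect-padded, unfolded into `img_size x img_size` tiles, run through the MAE tile by tile, and stitched back with the inverse einops pattern. Below is a minimal, self-contained sketch of that round trip (not part of the commit); the model call is replaced by an identity stand-in and the shapes are illustrative:

```python
import numpy as np
import torch
from einops import rearrange

img_size = 4
b, c, t = 1, 6, 3
original_h, original_w = 10, 7  # arbitrary size, not divisible by img_size

x = np.arange(b * c * t * original_h * original_w, dtype="float32").reshape(
    b, c, t, original_h, original_w
)

# Reflect-pad H and W up to a multiple of img_size (same padding call as the script)
pad_h = img_size - (original_h % img_size)
pad_w = img_size - (original_w % img_size)
x_pad = np.pad(x, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect")

# Tile into non-overlapping img_size x img_size windows and flatten them into a batch
batch = torch.tensor(x_pad)
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size)

# run_model(model, window, mask_ratio, device) would process each window here;
# an identity op stands in for the MAE reconstruction in this sketch
rec = windows.clone()

# Stitch the windows back together and crop the padding away
rec = rearrange(
    rec, "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
    h=img_size, w=img_size, b=b, c=c, t=t, h1=h1, w1=w1,
)
rec = rec[..., :original_h, :original_w]

assert torch.equal(rec, torch.tensor(x))  # round trip recovers the original input
```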
README.md
CHANGED
@@ -36,9 +36,11 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
 There is an inference script (`Prithvi_run_inference.py`) that allows running the image reconstruction on a set of HLS images assumed to be from the same location at different time steps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
 
 ```
-python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
+python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --input_indices <space separated 0-based indices of channels to select from input> --mask_ratio 0.5 --img_size <length of one side of square input shape>
 ```
 
+This demo is a starting point that can be used to generalize to different input shapes / types.
+
 ### Finetuning examples
 Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [github](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). This also contains instructions to finetune the model for flood detection on the popular open access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
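Before picking `--input_indices` for the inference command documented above, it can help to check how many bands the input GeoTIFFs actually carry and in what order. A small sketch using rasterio (not part of the commit; the file name is a placeholder and band descriptions may be empty depending on how the file was produced):

```python
import rasterio

# Inspect one of the input files (e.g. t1.tif from the README example)
with rasterio.open("t1.tif") as src:
    print(f"{src.count} bands, dtype {src.dtypes[0]}, size {src.width}x{src.height}")
    for i, desc in enumerate(src.descriptions):  # 0-based, matching --input_indices
        print(i, desc)

# The model expects the six bands from the config, in this order:
# B02 (Blue), B03 (Green), B04 (Red), B05 (Narrow NIR), B06 (SWIR 1), B07 (SWIR 2).
# If the files contain extra bands, pass the 0-based positions of those six,
# e.g. --input_indices 1 2 3 4 5 6 (hypothetical band layout).
```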
|