Paolo-Fraccaro committed on
Commit
13d7913
1 Parent(s): 2e4b359

add inference script

Files changed (1)
  1. Prithvi_run_inference.py +339 -0
Prithvi_run_inference.py ADDED
@@ -0,0 +1,339 @@
+ import argparse
+ import functools
+ import os
+ from typing import List
+
+ import numpy as np
+ import rasterio
+ import torch
+ import yaml
+ from einops import rearrange
+
+ from Prithvi import MaskedAutoencoderViT
+
+
+ NO_DATA = -9999
+ NO_DATA_FLOAT = 0.0001
+ PERCENTILES = (0.1, 99.9)
+
+
+ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
+     """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
+     original range using *data_mean* and *data_std*, then the lowest and highest percentiles are
+     removed to enhance contrast. Data is rescaled to the (0, 1) range and stacked channels-first.
+     Args:
+         orig_img: torch.Tensor representing the original (reference) image with shape = (bands, H, W).
+         new_img: torch.Tensor representing the new image with shape = (bands, H, W).
+         channels: list of indices representing RGB channels.
+         data_mean: list of mean values for each band.
+         data_std: list of std values for each band.
+     Returns:
+         torch.Tensor with shape (num_channels, height, width) for the original image
+         torch.Tensor with shape (num_channels, height, width) for the new image
+     """
+
+     stack_c = [], []
+
+     for c in channels:
+         orig_ch = orig_img[c, ...]
+         valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
+         valid_mask[orig_ch == NO_DATA_FLOAT] = False  # nodata pixels were mapped to NO_DATA_FLOAT
+
+         # Back to original data range
+         orig_ch = (orig_ch * data_std[c]) + data_mean[c]
+         new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
+
+         # Rescale (enhancing contrast)
+         min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
+
+         orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
+         new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
+
+         # No data as zeros
+         orig_ch[~valid_mask] = 0
+         new_ch[~valid_mask] = 0
+
+         stack_c[0].append(orig_ch)
+         stack_c[1].append(new_ch)
+
+     # Channels first
+     stack_orig = torch.stack(stack_c[0], dim=0)
+     stack_rec = torch.stack(stack_c[1], dim=0)
+
+     return stack_orig, stack_rec
+
+
+ def read_geotiff(file_path: str):
+     """ Read all bands from *file_path* and return image + meta info.
+     Args:
+         file_path: path to image file.
+     Returns:
+         np.ndarray with shape (bands, height, width)
+         meta info dict
+     """
+
+     with rasterio.open(file_path) as src:
+         img = src.read()
+         meta = src.meta
+
+     return img, meta
+
+
+ def save_geotiff(image, output_path: str, meta: dict):
+     """ Save a multi-band image as a GeoTIFF file.
+     Args:
+         image: np.ndarray with shape (bands, height, width)
+         output_path: path where to save the image
+         meta: dict with meta info.
+     """
+
+     with rasterio.open(output_path, "w", **meta) as dest:
+         for i in range(image.shape[0]):
+             dest.write(image[i, :, :], i + 1)
+
+     return
+
+
+ def _convert_np_uint8(float_image: torch.Tensor):
+     """ Convert a float tensor with values in (0, 1) to a uint8 numpy array in (0, 255). """
+
+     image = float_image.numpy() * 255.0
+     image = image.astype(dtype=np.uint8)
+
+     return image
+
+
+ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
+     """ Build an input example by loading images in *file_paths*.
+     Args:
+         file_paths: list of file paths.
+         mean: list containing mean values for each band in the images in *file_paths*.
+         std: list containing std values for each band in the images in *file_paths*.
+     Returns:
+         np.ndarray containing the created example
+         list of meta info for each image in *file_paths*
+     """
+
+     imgs = []
+     metas = []
+
+     for file in file_paths:
+         img, meta = read_geotiff(file)
+
+         # Rescaling (don't normalize on nodata)
+         img = np.moveaxis(img, 0, -1)  # channels last for rescaling
+         img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
+
+         imgs.append(img)
+         metas.append(meta)
+
+     imgs = np.stack(imgs, axis=0)  # num_frames, img_size, img_size, C
+     imgs = np.moveaxis(imgs, -1, 0).astype('float32')  # C, num_frames, img_size, img_size
+     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
+
+     return imgs, metas
+
+
+ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
+     """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
+     Args:
+         model: MAE model to run.
+         input_data: torch.Tensor with shape (B, C, T, H, W).
+         mask_ratio: mask ratio to use.
+         device: device where model should run.
+     Returns:
+         2 torch.Tensor with shape (B, C, T, H, W): the reconstructed image and the mask image.
+     """
+
+     with torch.no_grad():
+         x = input_data.to(device)
+
+         _, pred, mask = model(x, mask_ratio)
+
+     # Create mask and prediction images (un-patchify)
+     mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
+     pred_img = model.unpatchify(pred).detach().cpu()
+
+     # Mix visible and predicted patches
+     rec_img = input_data.clone()
+     rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove
+
+     # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
+     mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
+
+     return rec_img, mask_img
+
+
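+ # Note on the mask handling above (explanatory; the 0/1 convention is taken from
+ # the code comment): the model returns one binary flag per token (0 = kept
+ # visible, 1 = masked out). Repeating each flag across its token's pixel values
+ # and un-patchifying yields a per-pixel mask aligned with pred_img, which is
+ # what allows masked pixels to be swapped in element-wise.
+
+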
+ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
+     """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
+     Args:
+         input_img: input torch.Tensor with shape (C, T, H, W).
+         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
+         mask_img: mask torch.Tensor with shape (C, T, H, W).
+         channels: list of indices representing RGB channels.
+         mean: list of mean values for each band.
+         std: list of std values for each band.
+         output_dir: directory where to save outputs.
+         meta_data: list of dicts with geotiff meta info.
+     """
+
+     for t in range(input_img.shape[1]):
+         rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
+                                                    new_img=rec_img[:, t, :, :],
+                                                    channels=channels, data_mean=mean,
+                                                    data_std=std)
+
+         rgb_mask = mask_img[channels, t, :, :] * rgb_orig
+
+         # Saving images
+
+         save_geotiff(image=_convert_np_uint8(rgb_orig),
+                      output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+         save_geotiff(image=_convert_np_uint8(rgb_pred),
+                      output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+         save_geotiff(image=_convert_np_uint8(rgb_mask),
+                      output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+
+ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str, mask_ratio: float):
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Get parameters --------
+
+     with open(yaml_file_path, 'r') as f:
+         params = yaml.safe_load(f)
+
+     # data related
+     num_frames = params['num_frames']
+     img_size = params['img_size']
+     bands = params['bands']
+     mean = params['data_mean']
+     std = params['data_std']
+
+     # model related
+     depth = params['depth']
+     patch_size = params['patch_size']
+     embed_dim = params['embed_dim']
+     num_heads = params['num_heads']
+     tubelet_size = params['tubelet_size']
+     decoder_embed_dim = params['decoder_embed_dim']
+     decoder_num_heads = params['decoder_num_heads']
+     decoder_depth = params['decoder_depth']
+
+     batch_size = params['batch_size']
+
+     mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
+
+     # We must have *num_frames* files to build one example!
+     assert len(data_files) == num_frames, "Number of data files must equal the expected number of frames."
+
+     if torch.cuda.is_available():
+         device = torch.device('cuda')
+     else:
+         device = torch.device('cpu')
+
+     print(f"Using {device} device.\n")
+
+     # Loading data ---------------------------------------------------------------------------------
+
+     input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
+
+     # Create model and load checkpoint -------------------------------------------------------------
+
+     model = MaskedAutoencoderViT(
+             img_size=img_size,
+             patch_size=patch_size,
+             num_frames=num_frames,
+             tubelet_size=tubelet_size,
+             in_chans=len(bands),
+             embed_dim=embed_dim,
+             depth=depth,
+             num_heads=num_heads,
+             decoder_embed_dim=decoder_embed_dim,
+             decoder_depth=decoder_depth,
+             decoder_num_heads=decoder_num_heads,
+             mlp_ratio=4.,
+             norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
+             norm_pix_loss=False)
+
+     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f"\n--> model has {total_params / 1e6:.2f} million params.\n")
+
+     model.to(device)
+
+     state_dict = torch.load(checkpoint, map_location=device)
+     model.load_state_dict(state_dict)
+     print(f"Loaded checkpoint from {checkpoint}")
+
+     # Running model --------------------------------------------------------------------------------
+
+     model.eval()
+     channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # B04 (red), B03 (green), B02 (blue) -> RGB order
+
+     # Build sliding window
+     batch = torch.tensor(input_data, device='cpu')
+     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
+     h1, w1 = windows.shape[3:5]
+     windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
+
+     # Split into batches if number of windows > batch_size
+     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
+     windows = torch.tensor_split(windows, num_batches, dim=0)
+
+     # Run model
+     rec_imgs = []
+     mask_imgs = []
+     for x in windows:
+         rec_img, mask_img = run_model(model, x, mask_ratio, device)
+         rec_imgs.append(rec_img)
+         mask_imgs.append(mask_img)
+
+     rec_imgs = torch.concat(rec_imgs, dim=0)
+     mask_imgs = torch.concat(mask_imgs, dim=0)
+
+     # Build images from patches
+     rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
+                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
+     mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
+                           h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
+
+     # Mix original image with patches
+     h, w = rec_imgs.shape[-2:]
+     rec_imgs_full = batch.clone()
+     rec_imgs_full[..., :h, :w] = rec_imgs
+
+     mask_imgs_full = torch.ones_like(batch)
+     mask_imgs_full[..., :h, :w] = mask_imgs
+
+     # Build RGB images
+     for d in meta_data:
+         d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
+
+     save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
+                   channels, mean, std, output_dir, meta_data)
+
+     print("Done!")
+
+
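+ # Shape sketch for the sliding-window logic in main() (toy sizes, assumed for
+ # illustration; the real values come from the YAML config):
+ #   x = torch.zeros(1, 6, 3, 448, 448)                         # (B, C, T, H, W), img_size = 224
+ #   w = x.unfold(3, 224, 224).unfold(4, 224, 224)              # (1, 6, 3, 2, 2, 224, 224)
+ #   w = rearrange(w, 'b c t h1 w1 h w -> (b h1 w1) c t h w')   # (4, 6, 3, 224, 224)
+ # The inverse rearrange after inference stitches the 2x2 grid of windows back
+ # into a single (1, 6, 3, 448, 448) image.
+
+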
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser('MAE run inference', add_help=False)
+
+     parser.add_argument('--data_files', required=True, type=str, nargs='+',
+                         help='Path to the data files. Assumes multi-band files.')
+     parser.add_argument('--yaml_file_path', type=str, required=True,
+                         help='Path to yaml file containing model training parameters.')
+     parser.add_argument('--checkpoint', required=True, type=str,
+                         help='Path to a checkpoint file to load from.')
+     parser.add_argument('--output_dir', required=True, type=str,
+                         help='Path to the directory where to save outputs.')
+     parser.add_argument('--mask_ratio', default=None, type=float,
+                         help='Masking ratio (percentage of removed patches). '
+                              'If None (default), use the same value used for pretraining.')
+     args = parser.parse_args()
+
+     main(**vars(args))
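
For reference, a hypothetical invocation of this script (the file names below are placeholders; pass as many --data_files as the num_frames value in the YAML, which must also provide the band statistics and model hyper-parameters read in main()):

python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path config.yaml --checkpoint checkpoint.pt --output_dir output/ --mask_ratio 0.75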