Upload burn_scar_batch_inference_script.py
Browse files
burn_scar_batch_inference_script.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from mmcv import Config
|
3 |
+
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,wrap_fp16_model)
|
4 |
+
from mmseg.models import build_segmentor
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import mmcv
|
8 |
+
import torch
|
9 |
+
from mmcv.parallel import collate, scatter
|
10 |
+
from mmcv.runner import load_checkpoint
|
11 |
+
|
12 |
+
from mmseg.datasets.pipelines import Compose
|
13 |
+
from mmseg.models import build_segmentor
|
14 |
+
|
15 |
+
from mmseg.datasets import build_dataloader, build_dataset, load_flood_test_data
|
16 |
+
import rasterio
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
from torchvision import transforms
|
21 |
+
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
22 |
+
|
23 |
+
from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
|
24 |
+
from mmseg.utils import custom # custom preprocessing for hls
|
25 |
+
import pdb
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import glob
|
29 |
+
import os
|
30 |
+
|
31 |
+
import time
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
|
35 |
+
parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
|
36 |
+
parser.add_argument('-config', help='path to model configuration file')
|
37 |
+
parser.add_argument('-ckpt', help='path to model checkpoint')
|
38 |
+
parser.add_argument('-input', help='path to input images folder for inference')
|
39 |
+
parser.add_argument('-output', help='directory path to save output images')
|
40 |
+
parser.add_argument('-input_type', help='file type of input images',default="tif")
|
41 |
+
|
42 |
+
args = parser.parse_args()
|
43 |
+
|
44 |
+
return args
|
45 |
+
|
46 |
+
def open_tiff(fname):
|
47 |
+
|
48 |
+
with rasterio.open(fname, "r") as src:
|
49 |
+
|
50 |
+
data = src.read()
|
51 |
+
|
52 |
+
return data
|
53 |
+
|
54 |
+
def write_tiff(img_wrt, filename, metadata):
|
55 |
+
|
56 |
+
"""
|
57 |
+
It writes a raster image to file.
|
58 |
+
|
59 |
+
:param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands)
|
60 |
+
:param filename: file path to the output file
|
61 |
+
:param metadata: metadata to use to write the raster to disk
|
62 |
+
:return:
|
63 |
+
"""
|
64 |
+
|
65 |
+
with rasterio.open(filename, "w", **metadata) as dest:
|
66 |
+
|
67 |
+
if len(img_wrt.shape) == 2:
|
68 |
+
|
69 |
+
img_wrt = img_wrt[None]
|
70 |
+
|
71 |
+
for i in range(img_wrt.shape[0]):
|
72 |
+
dest.write(img_wrt[i, :, :], i + 1)
|
73 |
+
|
74 |
+
|
75 |
+
def get_meta(fname):
|
76 |
+
|
77 |
+
with rasterio.open(fname, "r") as src:
|
78 |
+
|
79 |
+
meta = src.meta
|
80 |
+
|
81 |
+
return meta
|
82 |
+
|
83 |
+
def preprocess_image(data, means, stds, nodata=-9999):
|
84 |
+
|
85 |
+
data=np.where(data == nodata, 0, data)
|
86 |
+
data = data.astype(np.float32)
|
87 |
+
|
88 |
+
if len(data)==2:
|
89 |
+
(x, y) = data
|
90 |
+
else:
|
91 |
+
x=data
|
92 |
+
y=np.full((x.shape[-2], x.shape[-1]), -1)
|
93 |
+
|
94 |
+
im, label = x.copy(), y.copy()
|
95 |
+
label = label.astype(np.float64)
|
96 |
+
|
97 |
+
im1 = im[0] # red
|
98 |
+
im2 = im[1] # green
|
99 |
+
im3 = im[2] # blue
|
100 |
+
im4 = im[3] # NIR narrow
|
101 |
+
im5 = im[4] # swir 1
|
102 |
+
im6 = im[5] # swir 2
|
103 |
+
|
104 |
+
dim = x.shape[-1]
|
105 |
+
label = label.squeeze()
|
106 |
+
norm = transforms.Normalize(means, stds)
|
107 |
+
ims = [torch.stack((transforms.ToTensor()(im1).squeeze(),
|
108 |
+
transforms.ToTensor()(im2).squeeze(),
|
109 |
+
transforms.ToTensor()(im3).squeeze(),
|
110 |
+
transforms.ToTensor()(im4).squeeze(),
|
111 |
+
transforms.ToTensor()(im5).squeeze(),
|
112 |
+
transforms.ToTensor()(im6).squeeze()))]
|
113 |
+
ims = [norm(im) for im in ims]
|
114 |
+
ims = torch.stack(ims)
|
115 |
+
|
116 |
+
label = transforms.ToTensor()(label).squeeze()
|
117 |
+
|
118 |
+
_img_metas = {
|
119 |
+
'ori_shape': (dim, dim),
|
120 |
+
'img_shape': (dim, dim),
|
121 |
+
'pad_shape': (dim, dim),
|
122 |
+
'scale_factor': [1., 1., 1., 1.],
|
123 |
+
'flip': False, # needs flip direction specified
|
124 |
+
}
|
125 |
+
|
126 |
+
img_metas = [_img_metas] * 1
|
127 |
+
return {"img": ims,
|
128 |
+
"img_metas": img_metas,
|
129 |
+
"gt_semantic_seg": label}
|
130 |
+
|
131 |
+
|
132 |
+
def load_model(config, ckpt):
|
133 |
+
|
134 |
+
print('Loading configuration...')
|
135 |
+
cfg = Config.fromfile(config)
|
136 |
+
print('Building model...')
|
137 |
+
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
|
138 |
+
print('Loading checkpoint...')
|
139 |
+
checkpoint = load_checkpoint(model,ckpt, map_location='cpu')
|
140 |
+
print('Evaluating model...')
|
141 |
+
model = MMDataParallel(model, device_ids=[0])
|
142 |
+
model.eval()
|
143 |
+
|
144 |
+
return model
|
145 |
+
|
146 |
+
|
147 |
+
def inference_on_file(model, target_image, output_image, means, stds):
|
148 |
+
|
149 |
+
try:
|
150 |
+
st = time.time()
|
151 |
+
data_orig = open_tiff(target_image)
|
152 |
+
meta = get_meta(target_image)
|
153 |
+
nodata = meta['nodata'] if meta['nodata'] is not None else -9999
|
154 |
+
|
155 |
+
data = preprocess_image(data_orig, means, stds, nodata)
|
156 |
+
|
157 |
+
small_fixed_size_arrs = custom.split_and_pad(data['img'][:,:,None,:,:], (1, 6, 1, 224, 224))
|
158 |
+
single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]
|
159 |
+
print('Running inference...')
|
160 |
+
with torch.no_grad():
|
161 |
+
result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
|
162 |
+
print("Result: Unique Values: ",np.unique(result))
|
163 |
+
|
164 |
+
print("Output has shape: " + str(result[0].shape))
|
165 |
+
#### TO DO: Post process (e.g. morphological operations)
|
166 |
+
|
167 |
+
result = custom.merge_and_unpad(result, (data_orig.shape[-2],data_orig.shape[-1]), (224, 224))
|
168 |
+
|
169 |
+
print("Result: Unique Values: ",np.unique(result))
|
170 |
+
|
171 |
+
##### Save file to disk
|
172 |
+
meta["count"] = 1
|
173 |
+
meta["dtype"] = "int16"
|
174 |
+
meta["compress"] = "lzw"
|
175 |
+
meta["nodata"] = -1
|
176 |
+
meta["nodata"] = nodata
|
177 |
+
print('Saving output...')
|
178 |
+
# pdb.set_trace()
|
179 |
+
result = np.where(data_orig[0] == nodata, nodata, result)
|
180 |
+
|
181 |
+
write_tiff(result, output_image, meta)
|
182 |
+
et = time.time()
|
183 |
+
print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)
|
184 |
+
|
185 |
+
except:
|
186 |
+
print(f'Error on image {target_image} \nContinue to next input')
|
187 |
+
|
188 |
+
def main():
|
189 |
+
|
190 |
+
args = parse_args()
|
191 |
+
|
192 |
+
model = load_model(args.config, args.ckpt)
|
193 |
+
image_pattern = "*merged"
|
194 |
+
target_images = glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type))
|
195 |
+
|
196 |
+
print('Identified images to predict on: ' + str(len(target_images)))
|
197 |
+
|
198 |
+
if not os.path.isdir(args.output):
|
199 |
+
os.mkdir(args.output)
|
200 |
+
|
201 |
+
means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])
|
202 |
+
|
203 |
+
for i, target_image in enumerate(target_images):
|
204 |
+
|
205 |
+
print(f'Working on Image {i}')
|
206 |
+
output_image = os.path.join(args.output,target_image.split("/")[-1].split(f"_{image_pattern[1:]}.")[0]+'_pred.'+args.input_type)
|
207 |
+
|
208 |
+
inference_on_file(model, target_image, output_image, means, stds)
|
209 |
+
|
210 |
+
print("Running metric eval")
|
211 |
+
|
212 |
+
gt_dir = "/home/workdir/hls-foundation/data/burn_scars/validation"
|
213 |
+
pred_dir = args.output
|
214 |
+
avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
|
215 |
+
print("Average Dice score:", avg_dice_score)
|
216 |
+
|
217 |
+
|
218 |
+
if __name__ == "__main__":
|
219 |
+
main()
|