Upload burn_scar_batch_inference_script.py #3
by rbavery

Files changed:
- README.md +36 -0
- burn_scar_batch_inference_script.py +219 -0
- custom.py +191 -0
- requirements.txt +47 -0
README.md
CHANGED

````diff
@@ -33,6 +33,42 @@ Code for Finetuning is available through [github](https://github.com/NASA-IMPACT
 Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py)
 
+To run inference, first install the dependencies:
+
+```
+mamba create -n prithvi-burn-scar python=3.10 pycocotools ncurses
+mamba activate prithvi-burn-scar
+pip install --upgrade pip && \
+pip install -r requirements.txt && \
+mim install mmcv-full==1.5.0
+```
+
+#### Instructions for downloading from [HuggingFace datasets](https://huggingface.co/datasets)
+
+1. Create an account at https://huggingface.co/join
+2. Install `git` following https://git-scm.com/downloads
+3. Install git-lfs with `sudo apt install git-lfs` and `git lfs install`
+4. Run the following commands to download the HLS dataset. You may need to enter your HuggingFace username/password to do the `git clone`.
+
+```
+mkdir -p data
+cd data/
+git clone https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars burn_scars
+tar -xzvf burn_scars/hls_burn_scars.tar.gz -C ./
+```
+
+With the dataset and the environment in place, you can now run the inference script:
+
+```
+python burn_scar_batch_inference_script.py \
+-config burn_scars_Prithvi_100M.py \
+-ckpt burn_scars_Prithvi_100M.pth \
+-input data/burn_scars/validation \
+-output data/burn_scars/inference_output \
+-input_type tif
+```
 
 ### Results
````
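For reference, each output GeoTIFF is named after its input scene: the script below strips the `_merged` suffix from the input filename and appends `_pred`. A minimal sketch of that naming logic (the scene id is a hypothetical example):

```python
import os

# Hypothetical HLS scene following the "*merged.tif" pattern the script globs for.
target_image = "data/burn_scars/validation/HLS.S30.T10SEH.2018190.v1.4_merged.tif"
output_dir = "data/burn_scars/inference_output"

# Mirrors the naming logic in burn_scar_batch_inference_script.py:
# keep everything before "_merged." and append "_pred.tif".
stem = os.path.basename(target_image).split("_merged.")[0]
output_image = os.path.join(output_dir, stem + "_pred.tif")
print(output_image)  # data/burn_scars/inference_output/HLS.S30.T10SEH.2018190.v1.4_pred.tif
```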
burn_scar_batch_inference_script.py
ADDED @@ -0,0 +1,219 @@

```python
import argparse
import glob
import os
import time

import numpy as np
import rasterio
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmseg.models import build_segmentor
from torchvision import transforms

import custom  # custom preprocessing for HLS


def parse_args():
    parser = argparse.ArgumentParser(description="Inference on burn scar fine-tuned model")
    parser.add_argument('-config', help='path to model configuration file')
    parser.add_argument('-ckpt', help='path to model checkpoint')
    parser.add_argument('-input', help='path to input images folder for inference')
    parser.add_argument('-output', help='directory path to save output images')
    parser.add_argument('-input_type', help='file type of input images', default="tif")
    return parser.parse_args()


def open_tiff(fname):
    with rasterio.open(fname, "r") as src:
        data = src.read()
    return data


def write_tiff(img_wrt, filename, metadata):
    """
    Write a raster image to file.

    :param img_wrt: numpy array containing the data (2D for a single band, 3D for multiple bands)
    :param filename: file path of the output file
    :param metadata: metadata to use when writing the raster to disk
    """
    with rasterio.open(filename, "w", **metadata) as dest:
        if len(img_wrt.shape) == 2:
            img_wrt = img_wrt[None]
        for i in range(img_wrt.shape[0]):
            dest.write(img_wrt[i, :, :], i + 1)


def get_meta(fname):
    with rasterio.open(fname, "r") as src:
        meta = src.meta
    return meta


def preprocess_image(data, means, stds, nodata=-9999):
    data = np.where(data == nodata, 0, data)
    data = data.astype(np.float32)

    if len(data) == 2:  # (image, label) pair
        (x, y) = data
    else:  # image only: use a dummy label filled with -1
        x = data
        y = np.full((x.shape[-2], x.shape[-1]), -1)

    im, label = x.copy(), y.copy()
    label = label.astype(np.float64)

    im1 = im[0]  # red
    im2 = im[1]  # green
    im3 = im[2]  # blue
    im4 = im[3]  # NIR narrow
    im5 = im[4]  # swir 1
    im6 = im[5]  # swir 2

    dim = x.shape[-1]
    label = label.squeeze()
    norm = transforms.Normalize(means, stds)
    ims = [torch.stack([transforms.ToTensor()(band).squeeze()
                        for band in (im1, im2, im3, im4, im5, im6)])]
    ims = [norm(im) for im in ims]
    ims = torch.stack(ims)

    label = transforms.ToTensor()(label).squeeze()

    _img_metas = {
        'ori_shape': (dim, dim),
        'img_shape': (dim, dim),
        'pad_shape': (dim, dim),
        'scale_factor': [1., 1., 1., 1.],
        'flip': False,  # needs flip direction specified
    }
    img_metas = [_img_metas] * 1

    return {"img": ims,
            "img_metas": img_metas,
            "gt_semantic_seg": label}


def load_model(config, ckpt):
    print('Loading configuration...')
    cfg = Config.fromfile(config)
    print('Building model...')
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    print('Loading checkpoint...')
    load_checkpoint(model, ckpt, map_location='cpu')
    print('Evaluating model...')
    model = MMDataParallel(model, device_ids=[0])
    model.eval()
    return model


def inference_on_file(model, target_image, output_image, means, stds):
    try:
        st = time.time()
        data_orig = open_tiff(target_image)
        meta = get_meta(target_image)
        nodata = meta['nodata'] if meta['nodata'] is not None else -9999

        data = preprocess_image(data_orig, means, stds, nodata)

        # Chip the scene into fixed-size 224x224 tiles, padding at the edges.
        small_fixed_size_arrs = custom.split_and_pad(data['img'][:, :, None, :, :], (1, 6, 1, 224, 224))
        single_chip_batch = [torch.vstack([torch.tensor(t) for t in small_fixed_size_arrs])]
        print('Running inference...')
        with torch.no_grad():
            result = model(single_chip_batch, data['img_metas'], return_loss=False, rescale=False)
        print("Result: Unique Values: ", np.unique(result))

        print("Output has shape: " + str(result[0].shape))
        # TODO: post-process (e.g. morphological operations)

        # Reassemble the per-chip predictions into the original scene footprint.
        result = custom.merge_and_unpad(result, (data_orig.shape[-2], data_orig.shape[-1]), (224, 224))

        print("Result: Unique Values: ", np.unique(result))

        # Save the prediction to disk as a single-band GeoTIFF.
        meta["count"] = 1
        meta["dtype"] = "int16"
        meta["compress"] = "lzw"
        meta["nodata"] = nodata
        print('Saving output...')
        # Mask out pixels that were nodata in the source scene.
        result = np.where(data_orig[0] == nodata, nodata, result)

        write_tiff(result, output_image, meta)
        et = time.time()
        print(f'Inference completed in {str(np.round(et - st, 1))} seconds. Output available at: ' + output_image)

    except Exception as e:
        print(f'Error on image {target_image}: {e}\nContinuing to next input')


def main():
    args = parse_args()

    model = load_model(args.config, args.ckpt)
    image_pattern = "*merged"
    target_images = glob.glob(os.path.join(args.input, image_pattern + "." + args.input_type))

    print('Identified images to predict on: ' + str(len(target_images)))

    if not os.path.isdir(args.output):
        os.mkdir(args.output)

    means, stds = custom.calculate_band_statistics(args.input, image_pattern, bands=[0, 1, 2, 3, 4, 5])

    for i, target_image in enumerate(target_images):
        print(f'Working on Image {i}')
        output_image = os.path.join(
            args.output,
            target_image.split("/")[-1].split(f"_{image_pattern[1:]}.")[0] + '_pred.' + args.input_type)
        inference_on_file(model, target_image, output_image, means, stds)

    print("Running metric eval")

    gt_dir = args.input  # ground-truth masks are expected alongside the validation inputs
    pred_dir = args.output
    avg_dice_score = custom.compute_metrics(gt_dir, pred_dir)
    print("Average Dice score:", avg_dice_score)


if __name__ == "__main__":
    main()
```
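Should you want to run a single scene programmatically rather than through the CLI, here is a minimal sketch (assumptions: a CUDA device for `MMDataParallel`, the config/checkpoint filenames from the README above, `custom.py` importable from the same directory, and a hypothetical scene filename):

```python
import custom
from burn_scar_batch_inference_script import load_model, inference_on_file

# Config and checkpoint names taken from the README invocation above.
model = load_model("burn_scars_Prithvi_100M.py", "burn_scars_Prithvi_100M.pth")

# Normalization statistics are computed over the same folder being predicted on.
means, stds = custom.calculate_band_statistics(
    "data/burn_scars/validation", "*merged", bands=[0, 1, 2, 3, 4, 5])

scene = "data/burn_scars/validation/HLS.S30.T10SEH.2018190.v1.4_merged.tif"  # hypothetical
inference_on_file(
    model, scene,
    "data/burn_scars/inference_output/HLS.S30.T10SEH.2018190.v1.4_pred.tif",
    means, stds)
```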
custom.py
ADDED @@ -0,0 +1,191 @@

```python
# custom.py: preprocessing, chipping, and evaluation utilities for HLS burn scar inference

import glob
import os
import re

import numpy as np
import rasterio
import torch
from torchmetrics import Dice


def calculate_band_statistics(image_directory, image_pattern, bands=[0, 1, 2, 3, 4, 5]):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Args:
        image_directory (str): Directory where the source GeoTIFF files are stored that are passed to the model for training.
        image_pattern (str): Pattern of the GeoTIFF file names that globs files for computing stats.
        bands (list, optional): List of bands to calculate statistics for. Defaults to [0, 1, 2, 3, 4, 5].

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each band.
    """
    # Initialize lists to store the means and standard deviations
    all_means = []
    all_stds = []

    # Use glob to get a list of all matching .tif images in the directory
    all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")

    # Make sure there are images to process
    if not all_images:
        raise Exception("No images found")

    # Get the number of bands
    num_bands = len(bands)

    # Initialize arrays to hold sums and sums of squares for each band
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    # Iterate over each image
    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # For each band, accumulate the sum, square sum, and pixel count
            for band in bands:
                data = src.read(band + 1)  # rasterio band index starts from 1
                band_sums[band] += np.nansum(data)
                band_sq_sums[band] += np.nansum(data**2)
                pixel_counts[band] += np.count_nonzero(~np.isnan(data))

    # Calculate means and standard deviations for each band
    for i in bands:
        mean = band_sums[i] / pixel_counts[i]
        std = np.sqrt((band_sq_sums[i] / pixel_counts[i]) - (mean**2))
        all_means.append(mean)
        all_stds.append(std)

    return all_means, all_stds


def split_and_pad(array, target_shape):
    """
    Splits the input array into smaller arrays of the target shape, padding if necessary.

    Args:
        array (numpy.ndarray): The input array. Must be of shape (batch, band, time, height, width).
        target_shape (tuple): The target shape of the smaller arrays, also
            (batch, band, time, height, width).

    Raises:
        ValueError: If the target shape is larger than the array shape.

    Returns:
        list[numpy.ndarray]: A list of the smaller arrays.
    """
    # Check that the target height and width fit within the array
    if any(t > a for t, a in zip(target_shape[-2:], array.shape[-2:])):
        raise ValueError('Target shape must be smaller or equal to the array shape.')

    # Calculate how much padding is needed to reach a multiple of the target size
    pad_h = (target_shape[-2] - array.shape[-2] % target_shape[-2]) % target_shape[-2]
    pad_w = (target_shape[-1] - array.shape[-1] % target_shape[-1]) % target_shape[-1]

    # Apply padding to the bottom and right edges of the array
    padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))

    # Split the array into smaller arrays of the target shape
    result = []
    for i in range(0, padded_array.shape[-2], target_shape[-2]):
        for j in range(0, padded_array.shape[-1], target_shape[-1]):
            result.append(padded_array[..., i:i+target_shape[-2], j:j+target_shape[-1]])

    return result


def merge_and_unpad(np_array_list, original_shape, target_shape):
    """
    Assembles smaller numpy arrays back into the original larger numpy array, removing padding if necessary.

    Args:
        np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays derived from split_and_pad.
        original_shape (tuple): The original (height, width) of the larger numpy array.
        target_shape (tuple): The (height, width) of the smaller numpy arrays.

    Returns:
        numpy.ndarray: The original larger numpy array.
    """
    # Calculate how much padding was added
    pad_h = (target_shape[0] - original_shape[0] % target_shape[0]) % target_shape[0]
    pad_w = (target_shape[1] - original_shape[1] % target_shape[1]) % target_shape[1]

    # Calculate the shape of the padded larger array
    padded_shape = (original_shape[0] + pad_h, original_shape[1] + pad_w)

    # Calculate the number of smaller arrays in each dimension
    num_arrays_h = padded_shape[0] // target_shape[0]
    num_arrays_w = padded_shape[1] // target_shape[1]

    # Reshape the list of smaller arrays back into the shape of the padded larger array
    merged_array = np.stack(np_array_list).reshape(num_arrays_h, num_arrays_w, *target_shape)

    # Rearrange the array dimensions
    merged_array = merged_array.transpose(0, 2, 1, 3).reshape(*padded_shape)

    # Remove the padding
    unpadded_array = merged_array[:original_shape[0], :original_shape[1]]

    return unpadded_array


def compute_metrics(gt_dir, pred_dir):
    """
    Compute the Dice similarity coefficient between the predicted and ground truth images.

    Args:
        gt_dir (str): Directory where the ground truth images are stored.
        pred_dir (str): Directory where the predicted images are stored.

    Returns:
        Tensor: Dice similarity coefficient score.
    """
    dice_metric = Dice()

    # Find all .tif files in the prediction directory
    pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))

    # Iterate over each prediction file
    for pred_file in pred_files:
        # Extract the unique id from the file name
        unique_id = re.search(r'HLS\..*\.v1\.4', os.path.basename(pred_file))

        if unique_id is not None:
            unique_id = unique_id.group()

            # Create the matching pattern for the ground truth directory
            gt_file_pattern = os.path.join(gt_dir, f"*{unique_id}*mask.tif")

            # Glob the file pattern
            gt_files = glob.glob(gt_file_pattern)

            # If we found exactly one matching ground truth file
            if len(gt_files) == 1:
                gt_file = gt_files[0]

                # Read the .tif files
                with rasterio.open(gt_file) as src:
                    gt_img = src.read(1)  # ground truth image

                with rasterio.open(pred_file) as src:
                    pred_img = src.read(1)  # predicted image

                # Make sure the images are binary (values are 0 or 1)
                gt_img = (gt_img > 0).astype(np.uint8)
                pred_img = (pred_img > 0).astype(np.uint8)

                # Convert numpy arrays to PyTorch tensors
                gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
                pred_img_tensor = torch.from_numpy(pred_img).long().flatten()

                # Update the running Dice metric
                dice_metric.update(pred_img_tensor, gt_img_tensor)

            else:
                print(f"No matching ground truth file for prediction file {pred_file}.")

    # Compute the Dice score accumulated over all matched files
    dice_score = dice_metric.compute()
    return dice_score
```
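To make the chip/merge round trip concrete, here is a small self-contained sketch; the 512x512 scene size is an assumption for illustration, and the helpers are the ones defined above:

```python
import numpy as np
from custom import split_and_pad, merge_and_unpad

# A fake 6-band, single-timestep scene in (batch, band, time, height, width) layout.
scene = np.random.rand(1, 6, 1, 512, 512).astype(np.float32)

# 512 is not a multiple of 224, so the scene is padded to 672x672 and split into 9 chips.
chips = split_and_pad(scene, (1, 6, 1, 224, 224))
assert len(chips) == 9 and chips[0].shape == (1, 6, 1, 224, 224)

# Per-chip 2D slices (band 0, time 0) reassemble losslessly to the original footprint.
tiles = [chip[0, 0, 0] for chip in chips]
merged = merge_and_unpad(tiles, (512, 512), (224, 224))
assert merged.shape == (512, 512)
assert np.allclose(merged, scene[0, 0, 0])
```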
requirements.txt
ADDED @@ -0,0 +1,47 @@

```
boxsdk==3.6.2
cityscapesscripts==2.2.1
codecov
detail==0.2.2
docutils==0.16.0
einops==0.6.0
flake8
interrogate
jupyterlab==4.0.1
matplotlib==3.5.1
mmcls>=0.20.1
mmdet==2.22.0
model_archiver==1.0.3
myst-parser
-e git+https://github.com/gaotongxiao/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
natsort==8.3.1
numpy==1.21.6
onnx==1.13.1
onnxruntime==1.14.1
onnx2torch
opencv-python==4.7.0.72
openmim
packaging==21.3
pandas==1.3.5
pavi==0.0.1
Pillow==9.4.0
pip-tools
prettytable==3.6.0
pytest==7.1.3
rasterio==1.3.4
requests==2.28.2
scikit-learn
scipy==1.7.3
scikit-image
seaborn==0.12.2
sphinx==4.0.2
sphinx_copybutton
sphinx_markdown_tables
tensorrt==8.5.3.1
timm==0.4.12
torch==1.9.0+cu111
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.0
torchmetrics
ts==0.5.1
xdoctest>=0.10.0
yapf
```