|
import rasterio as rio |
|
import pathlib |
|
import opensr_test |
|
import matplotlib.pyplot as plt |
|
|
|
from typing import Callable, Union |
|
|
|
|
|
def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the GeoTIFFs for one or several dataset snippets.

    Each selected snippet is processed by ``create_geotiff_batch``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that returns a dictionary with the
            following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
        datasets (Union[str, list]): The dataset snippets to process. Either
            a list of snippet names, a single snippet name, or the string
            "all" to process everything in ``opensr_test.datasets``.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Forwarded to ``create_geotiff_batch``; the
            original documentation states that the dataset is redownloaded
            when True. Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``create_geotiff_batch``.
    """
    if isinstance(datasets, str):
        # A bare string must not be iterated character by character:
        # "all" expands to every known snippet, any other string is treated
        # as the name of a single snippet.
        datasets = opensr_test.datasets if datasets == "all" else [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )
|
|
|
def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create all the GeoTIFFs and preview PNGs for one dataset snippet.

    For every image of the snippet, ``fn`` is called and its "sr" output is
    written as a GeoTIFF in ``<output_path>/results/SR/<snippet>/geotiff``.
    A side-by-side LR / SR / HR preview is saved as a PNG in
    ``<output_path>/results/SR/<snippet>/png``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)`` that returns a
            dictionary with the following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
            The images are indexed as (band, row, column) arrays here.
        snippet (str): The dataset snippet to use to run the fn function.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Passed to ``opensr_test.load``; the original
            documentation states that the dataset is redownloaded when True.
            Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to ``fn``.

    Returns:
        pathlib.Path: The output path where the GeoTIFFs are saved.
    """
    output_path = pathlib.Path(output_path) / "results" / "SR"

    # parents=True also creates <output_path>/results/SR/<snippet>.
    output_path_dataset_geotiff = output_path / snippet / "geotiff"
    output_path_dataset_geotiff.mkdir(parents=True, exist_ok=True)

    output_path_dataset_png = output_path / snippet / "png"
    output_path_dataset_png.mkdir(parents=True, exist_ok=True)

    dataset = opensr_test.load(snippet, force=force)
    lr_dataset = dataset["L2A"]
    hr_dataset = dataset["HRharm"]
    metadata = dataset["metadata"]

    total = len(lr_dataset)
    for index in range(total):
        print(f"Processing {index}/{total}")

        # Run the model on the current LR/HR pair.
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        # One metadata row describes the current image.
        row = metadata.iloc[index]
        image_name = row["hr_file"]
        crs = row["crs"]

        # "affine" is a comma-separated list of floats. Elements 2 and 5 are
        # used as the west/north origin, element 0 as the pixel width and the
        # negated element 4 as the pixel height.
        # NOTE(review): presumably (a, b, c, d, e, f) affine order - confirm.
        transform_list = [float(x) for x in row["affine"].split(",")]
        transform_rio = rio.transform.from_origin(
            transform_list[2],
            transform_list[5],
            transform_list[0],
            transform_list[4] * -1
        )

        # The raster grid is sized from the HR image because the transform
        # above comes from the HR metadata.
        meta_img = {
            "driver": "GTiff",
            "count": 3,
            "dtype": "uint16",
            "height": results["hr"].shape[1],
            "width": results["hr"].shape[2],
            "crs": crs,
            "transform": transform_rio,
            "compress": "deflate",
            "predictor": 2,
            "tiled": True
        }

        geotiff_file = output_path_dataset_geotiff / (image_name + ".tif")
        with rio.open(geotiff_file, "w", **meta_img) as dst:
            dst.write(results["sr"])

        # Side-by-side preview. Images are moved to (row, column, band) and
        # divided by 3000 before display (scaling factor kept from the
        # original code).
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for panel, key in zip(ax, ("lr", "sr", "hr")):
                panel.imshow(results[key].transpose(1, 2, 0) / 3000)
                panel.set_title(key.upper())
                panel.axis("off")

            fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
            fig.savefig(output_path_dataset_png / (image_name + ".png"))
        finally:
            # Close exactly this figure. Calling plt.clf() after plt.close()
            # would make pyplot create a new empty figure that is never
            # closed.
            plt.close(fig)

    return output_path_dataset_geotiff
|
|
|
|
|
|
|
|
|
def run(model_path: str) -> pathlib.Path:
    """Run all the metrics for a specific model (placeholder).

    This function is not implemented yet: it performs no work and returns
    None, regardless of the annotated return type.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the metrics are
            saved as a pickle file. Currently None is returned.
    """
    return None
|
|
|
|
|
def plot(model_path: str) -> pathlib.Path:
    """Generate the plots and tables for a specific model (placeholder).

    This function is not implemented yet: it performs no work and returns
    None, regardless of the annotated return type.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the plots and
            tables are saved. Currently None is returned.
    """
    return None
|
|
|
|
|
|