Spaces:
Sleeping
Sleeping
import random | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from pytorch_lightning import LightningDataModule | |
import os | |
import re | |
import yaml | |
import rasterio | |
import dvc.api | |
# Pipeline parameters resolved through DVC when this module is imported.
params = dvc.api.params_show()
# Number of timesteps stacked in every chip; PermuteTransform expects
# N_TIMESTEPS * 6 channels per image.
N_TIMESTEPS = params["number_of_timesteps"]
class ToTensorTransform(object):
    """Callable that converts array-like input into a tensor of a fixed dtype."""

    def __init__(self, dtype):
        # Target dtype applied to every converted input.
        self.dtype = dtype

    def __call__(self, data):
        """Return ``data`` as a new ``torch.Tensor`` with the configured dtype."""
        target_dtype = self.dtype
        converted = torch.tensor(data, dtype=target_dtype)
        return converted
class NormalizeTransform(object):
    """Callable wrapper around ``transforms.Normalize`` with stored statistics."""

    def __init__(self, means, stds):
        # Per-channel statistics handed to transforms.Normalize on every call.
        self.mean, self.std = means, stds

    def __call__(self, data):
        """Normalize ``data`` with the stored means and standard deviations."""
        normalizer = transforms.Normalize(self.mean, self.std)
        return normalizer(data)
class PermuteTransform:
    """Regroup a flat (timesteps * bands, H, W) stack into (bands, timesteps, H, W).

    Parameters
    ----------
    n_timesteps : int, optional
        Number of timesteps stacked along the channel axis. Defaults to the
        module-level ``N_TIMESTEPS``.
    n_bands : int, optional (default=6)
        Number of spectral bands per timestep.
    """

    def __init__(self, n_timesteps=None, n_bands=6):
        # Fall back to the module constant only when no explicit value is given,
        # so PermuteTransform() behaves exactly as before.
        self.n_timesteps = N_TIMESTEPS if n_timesteps is None else n_timesteps
        self.n_bands = n_bands

    def __call__(self, data):
        """Return ``data`` reshaped and permuted to bands-first layout.

        Raises
        ------
        ValueError
            If the leading (channel) dimension is not ``n_timesteps * n_bands``.
        """
        height, width = data.shape[-2:]
        expected_channels = self.n_timesteps * self.n_bands
        # Ensure the channel dimension is as expected; report the dimension
        # that is actually checked (dim 0).
        if data.shape[0] != expected_channels:
            raise ValueError(f"Expected {expected_channels} channels, got {data.shape[0]}")
        # Step 1: group the flat channel axis into n_timesteps groups of n_bands.
        # reshape (rather than view) also accepts non-contiguous input.
        data = data.reshape(self.n_timesteps, self.n_bands, height, width)
        # Step 2: bring the bands to the front.
        # NOTE: Prithvi wants it bands first; shape is now (n_bands, n_timesteps, height, width).
        data = data.permute(1, 0, 2, 3)
        return data
class RandomFlipAndJitterTransform: | |
""" | |
Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask. | |
Parameters: | |
----------- | |
flip_prob : float, optional (default=0.5) | |
Probability of applying horizontal and vertical flips to the image and mask. | |
Each flip (horizontal and vertical) is applied independently based on this probability. | |
jitter_std : float, optional (default=0.02) | |
Standard deviation of the Gaussian noise added to the image channels for jitter. | |
This value controls the intensity of the random noise applied to the image channels. | |
Effects of Parameters: | |
---------------------- | |
flip_prob: | |
- Higher flip_prob increases the likelihood of the image and mask being flipped. | |
- A value of 0 means no flipping, while a value of 1 means always flip. | |
jitter_std: | |
- Higher jitter_std increases the intensity of the noise added to the image channels. | |
- A value of 0 means no noise, while larger values add more significant noise. | |
""" | |
def __init__(self, flip_prob=0.5, jitter_std=0.02): | |
self.flip_prob = flip_prob | |
self.jitter_std = jitter_std | |
def __call__(self, img, mask, field_ids): | |
# Shapes (..., H, W)| img: torch.Size([6, N_TIMESTEPS, 224, 224]), mask: torch.Size([N_TIMESTEPS, 224, 224]), field_ids: torch.Size([1, 224, 224]) | |
# Temporarily convert field_ids to int32 for flipping (flip not implemented for uint16) | |
field_ids = field_ids.to(torch.int32) | |
# Random horizontal flip | |
if random.random() < self.flip_prob: | |
img = torch.flip(img, [2]) | |
mask = torch.flip(mask, [1]) | |
field_ids = torch.flip(field_ids, [1]) | |
# Random vertical flip | |
if random.random() < self.flip_prob: | |
img = torch.flip(img, [3]) | |
mask = torch.flip(mask, [2]) | |
field_ids = torch.flip(field_ids, [2]) | |
# Convert field_ids back to uint16 | |
field_ids = field_ids.to(torch.uint16) | |
# Channel jitter | |
noise = torch.randn(img.size()) * self.jitter_std | |
img += noise | |
return img, mask, field_ids | |
def get_img_transforms():
    """Return the image transform pipeline, which is currently an empty ``Compose``."""
    no_steps = []
    return transforms.Compose(no_steps)
def get_mask_transforms():
    """Return the mask transform pipeline, which is currently an empty ``Compose``."""
    no_steps = []
    return transforms.Compose(no_steps)
class GeospatialDataset(Dataset):
    """Dataset of chips stored as GeoTIFFs under ``<data_dir>/chips``.

    For every selected fold, files named ``*_fold_<fold>_merged.tif`` are used
    as images; the matching ``*_mask.tif`` and ``*_field_ids.tif`` files in the
    same directory provide the targets and field ids. Normalisation statistics
    are read from ``<data_dir>/chips_stats.yaml``.

    Parameters
    ----------
    data_dir : str
        Directory containing ``chips/`` and ``chips_stats.yaml``.
    fold_indicies : iterable
        Fold identifiers to include.
    transform_img, transform_mask, transform_field_ids : callable, optional
        Extra transforms applied after the built-in load transforms.
    debug : bool
        Print additional information while loading statistics.
    subset_size : int, optional
        If given and smaller than the number of chips, a random subset of this
        size is kept.
    data_augmentation : dict, optional
        Augmentation config with optional keys ``enabled``, ``flip_prob`` and
        ``jitter_std``. ``None`` disables augmentation; a dict without an
        ``enabled`` key is treated as enabled.

    Raises
    ------
    FileNotFoundError
        If the stats file, or the mask / field-id file of a selected chip, is missing.
    ValueError
        If no folds are given or a fold has no entry in the stats file.
    """

    _IMAGE_SUFFIX = "_merged.tif"

    def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
        # Materialise once: the folds are iterated for the stats and for file selection.
        fold_indicies = list(fold_indicies)
        self.data_dir = data_dir
        self.chips_dir = os.path.join(data_dir, 'chips')
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.transform_field_ids = transform_field_ids
        self.debug = debug
        self.images = []
        self.masks = []
        self.field_ids = []
        self.data_augmentation = data_augmentation

        self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
        self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
        self.transform_mask_load = self.get_mask_load_transforms()
        self.transform_field_ids_load = self.get_field_ids_load_transforms()

        # Select files by fold. Sorting makes the sample order independent of
        # the arbitrary order returned by os.listdir.
        chip_files = sorted(os.listdir(self.chips_dir))
        available = set(chip_files)
        pattern = self._fold_file_pattern(fold_indicies)
        for file_name in chip_files:
            if pattern.fullmatch(file_name) is None:
                continue
            stem = file_name[:-len(self._IMAGE_SUFFIX)]
            mask_file = stem + "_mask.tif"
            field_ids_file = stem + "_field_ids.tif"
            # Fail here with a clear message instead of later inside __getitem__.
            missing = [name for name in (mask_file, field_ids_file) if name not in available]
            if missing:
                raise FileNotFoundError(
                    f"Missing companion file(s) for {file_name} in {self.chips_dir}: {', '.join(missing)}"
                )
            self.images.append(file_name)
            self.masks.append(mask_file)
            self.field_ids.append(field_ids_file)

        # If subset_size is specified, randomly select a subset of the data
        if subset_size is not None and len(self.images) > subset_size:
            print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
            selected_indices = random.sample(range(len(self.images)), subset_size)
            self.images = [self.images[i] for i in selected_indices]
            self.masks = [self.masks[i] for i in selected_indices]
            self.field_ids = [self.field_ids[i] for i in selected_indices]

    @staticmethod
    def _fold_file_pattern(fold_indicies):
        """Compile the regex matching image file names of the given folds.

        Folds are matched as whole alternatives (so fold 10 is not confused
        with folds 1 and 0) and the pattern is meant for ``fullmatch`` so that
        sidecar files such as ``*.tif.aux.xml`` are rejected.
        """
        alternatives = "|".join(re.escape(str(fold)) for fold in fold_indicies)
        return re.compile(rf".*_fold_(?:{alternatives})_merged\.tif")

    @staticmethod
    def _read_raster(path, dtype):
        """Read all bands of the raster at ``path`` as ``dtype`` and close the file."""
        with rasterio.open(path) as src:
            return src.read().astype(dtype)

    def load_stats(self, fold_indicies, n_timesteps=3):
        """Load normalization statistics for dataset from YAML file.

        Returns two lists (means, stds) of scalar tensors: the per-band values
        aggregated over the given folds, repeated ``n_timesteps`` times.
        """
        stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
        if self.debug:
            print(f"Loading mean/std stats from {stats_path}")
        if not os.path.exists(stats_path):
            raise FileNotFoundError(f"mean/std stats file for dataset not found at {stats_path}")
        with open(stats_path, 'r') as file:
            stats = yaml.safe_load(file)

        mean_list, std_list, n_list = [], [], []
        for fold in fold_indicies:
            key = f'fold_{fold}'
            if key not in stats:
                raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
            if self.debug:
                print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
            mean_list.append(torch.tensor(stats[key]['mean'], dtype=torch.float32))  # per-band means of this fold
            std_list.append(torch.tensor(stats[key]['std'], dtype=torch.float32))  # per-band stds of this fold
            n_list.append(stats[key]['n_chips'])  # one chip count per fold
        if not mean_list:
            raise ValueError("At least one fold is required to load mean/std stats")

        # Aggregate means and stds over all folds, band by band.
        # NOTE(review): the mean is an unweighted average over folds, while the
        # variance is pooled with (n_i - 1) weights and ignores differences
        # between fold means. Left unchanged so normalisation values stay the same.
        n = torch.tensor(n_list, dtype=torch.float32)
        means, stds = [], []
        for channel in range(mean_list[0].shape[0]):
            means.append(torch.stack([fold_mean[channel] for fold_mean in mean_list]).mean())
            # Pooled variance:
            # \sqrt{\frac{\sum_{i} \sigma_i^2 (n_i - 1)}{\sum_{i} n_i - k}}, k = number of folds
            variances = torch.stack([fold_std[channel] ** 2 for fold_std in std_list])
            combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
            stds.append(torch.sqrt(combined_variance))

        # Repeat the per-band lists once per timestep (list repetition) so they
        # line up with the timesteps * bands channel layout of the loaded image.
        # Background on how torchvision reshapes mean/std:
        # https://github.com/pytorch/vision/blob/6e18cea3485066b7277785415bf2e0422dbdb9da/torchvision/transforms/_functional_tensor.py#L923
        return means * n_timesteps, stds * n_timesteps

    def get_img_load_transforms(self, means, stds):
        """Build the image pipeline: to float32 tensor, normalize, permute to bands-first."""
        return transforms.Compose([
            ToTensorTransform(torch.float32),
            NormalizeTransform(means, stds),
            PermuteTransform()
        ])

    def get_mask_load_transforms(self):
        """Build the mask pipeline: convert to a uint8 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint8)
        ])

    def get_field_ids_load_transforms(self):
        """Build the field-id pipeline: convert to a uint16 tensor."""
        return transforms.Compose([
            ToTensorTransform(torch.uint16)
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, (targets, field_ids))`` for the chip at ``idx``.

        ``targets`` is a tuple with one ``torch.long`` tensor per slice along
        the first dimension of the mask.
        """
        img_path = os.path.join(self.chips_dir, self.images[idx])
        mask_path = os.path.join(self.chips_dir, self.masks[idx])
        field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])

        # Read through a context manager so the file handles are closed again.
        img = self._read_raster(img_path, 'uint16')
        mask = self._read_raster(mask_path, 'uint8')
        field_ids = self._read_raster(field_ids_path, 'uint16')

        # Apply our base transforms
        img = self.transform_img_load(img)
        mask = self.transform_mask_load(mask)
        field_ids = self.transform_field_ids_load(field_ids)

        # Apply additional transforms passed from GeospatialDataModule if applicable
        if self.transform_img is not None:
            img = self.transform_img(img)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)
        if self.transform_field_ids is not None:
            field_ids = self.transform_field_ids(field_ids)

        # Apply data augmentation if enabled
        if self.data_augmentation is not None and self.data_augmentation.get('enabled', True):
            img, mask, field_ids = RandomFlipAndJitterTransform(
                flip_prob=self.data_augmentation.get('flip_prob', 0.5),
                jitter_std=self.data_augmentation.get('jitter_std', 0.02)
            )(img, mask, field_ids)

        # One target per tier (slice along the first mask dimension)
        targets = tuple(tier.type(torch.long) for tier in mask)
        return img, (targets, field_ids)
class GeospatialDataModule(LightningDataModule):
    """Lightning data module building train / val / test ``GeospatialDataset``s by fold.

    Parameters
    ----------
    data_dir : str
        Directory passed on to ``GeospatialDataset``.
    train_folds, val_folds, test_folds : iterable
        Fold identifiers per split; they must be mutually exclusive.
    batch_size : int
        Batch size used by all dataloaders.
    num_workers : int
        Worker processes per dataloader. ``0`` loads in the main process.
    debug : bool
        Forwarded to the datasets.
    subsets : dict, optional
        Optional ``'train'`` / ``'val'`` / ``'test'`` subset sizes.
    data_augmentation : dict, optional
        Augmentation config forwarded to the training dataset only.
        NOTE(review): ``None`` becomes ``{}`` here, and ``GeospatialDataset``
        treats a dict without an ``'enabled'`` key as enabled, so omitting the
        config turns augmentation ON for training - confirm this is intended.

    Raises
    ------
    ValueError
        If a fold set is ``None`` or the fold sets overlap.
    """

    def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.debug = debug
        self.subsets = subsets if subsets is not None else {}
        self.data_augmentation = data_augmentation if data_augmentation is not None else {}

        GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
        self.train_folds = train_folds
        self.val_folds = val_folds
        self.test_folds = test_folds

        # NOTE: Transforms on this level not used for now
        self.transform_img = get_img_transforms()
        self.transform_mask = get_mask_transforms()

    @staticmethod
    def validate_folds(train, val, test):
        """Raise ``ValueError`` unless all fold sets are given and pairwise disjoint."""
        if train is None or val is None or test is None:
            raise ValueError("All fold sets must be specified")
        train_set, val_set, test_set = set(train), set(val), set(test)
        if train_set & val_set or train_set & test_set or val_set & test_set:
            raise ValueError("Folds must be mutually exclusive")

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test' or ``None`` for both)."""
        print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
        common_params = {
            'data_dir': self.data_dir,
            'debug': self.debug,
            'data_augmentation': self.data_augmentation
        }
        common_params_val_test = {  # Never augment validation or test data
            **common_params,
            'data_augmentation': {
                'enabled': False
            }
        }
        if stage in ('fit', None):
            self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
            self.val_dataset = GeospatialDataset(fold_indicies=self.val_folds, subset_size=self.subsets.get('val', None), **common_params_val_test)
        if stage in ('test', None):
            self.test_dataset = GeospatialDataset(fold_indicies=self.test_folds, subset_size=self.subsets.get('test', None), **common_params_val_test)

    def _make_dataloader(self, dataset, shuffle=False):
        """Build a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader."""
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        """Return the test dataloader."""
        return self._make_dataloader(self.test_dataset)