# lightning.pytorch==2.1.1
seed_everything: 0
### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    # You can switch to TensorBoard for logging by uncommenting the class_path
    # below and commenting out the CSVLogger class_path that follows it.
    # class_path: TensorBoardLogger
    class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
    init_args:
      save_dir: ./experiments
      name: fine_tune_suhi
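    # For reference, a fuller TensorBoard alternative (a sketch: it assumes the
    # stock lightning.pytorch.loggers.TensorBoardLogger and reuses the same
    # save_dir and name, neither of which is prescribed by this config):
    # class_path: lightning.pytorch.loggers.TensorBoardLogger
    # init_args:
    #   save_dir: ./experiments
    #   name: fine_tune_suhi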
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 600
  max_epochs: 600
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  enable_checkpointing: true
  default_root_dir: ./experiments
out_dtype: float32
### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 1
    num_workers: 8
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0      # fill value for image pixels exposed by the rotation
          mask_value: 1 # fill value for mask pixels exposed by the rotation
          p: 0.5
      - class_path: ToTensorV2
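      # ToTensorV2 (from albumentations.pytorch) converts the augmented numpy
      # arrays to PyTorch tensors, so it should stay last in the transform list.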
    # Specify all bands which are present in the input data.
    dataset_bands:
      # 6 HLS bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      # ERA5-Land t2m_spatial_avg
      - 7
      # ERA5-Land t2m_sunrise_avg
      - 8
      # ERA5-Land t2m_midnight_avg
      - 9
      # ERA5-Land t2m_delta_avg
      - 10
      # cos_tod
      - 11
      # sin_tod
      - 12
      # cos_doy
      - 13
      # sin_doy
      - 14
    # Specify the bands which are used from the input data.
    # Bands 8-14 were discarded in the final model.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
    # Indices of the RGB bands within output_bands (RED, GREEN, BLUE), used for plotting.
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots of the training, validation, and test data splits:
    train_data_root: train/inputs
    train_label_data_root: train/targets
    val_data_root: val/inputs
    val_label_data_root: val/targets
    test_data_root: test/inputs
    test_label_data_root: test/targets
    img_grep: "*.inputs.tif"
    label_grep: "*.lst.tif"
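    # For example, the greps would match file pairs named like (hypothetical names):
    #   train/inputs/tile_0001.inputs.tif  <->  train/targets/tile_0001.lst.tif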
    # Value that nodata pixels in the input data are replaced with
    no_data_replace: 0
    # Value that nodata pixels in the label (target) data are replaced with
    # (matches ignore_index in the model configuration below)
    no_label_replace: -9999
    # Mean value of the training dataset per band
    means:
      - 702.4754028320312
      - 1023.23291015625
      - 1118.8924560546875
      - 2440.750732421875
      - 2052.705810546875
      - 1514.15087890625
      - 21.031919479370117
    # Standard deviation of the training dataset per band
    stds:
      - 554.8255615234375
      - 613.5565185546875
      - 745.929443359375
      - 715.0111083984375
      - 761.47607421875
      - 734.991943359375
      - 8.66781997680664
### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_L
      img_size: 224
      backbone_drop_path_rate: 0.3
      decoder_channels: 256
      in_channels: 7
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - 7
      num_frames: 1
    loss: rmse
    aux_heads:
      - name: aux_head
        decoder: IdentityDecoder
        decoder_args:
          head_dropout: 0.5
          head_channel_list:
            - 1
          head_final_act: torch.nn.LazyLinear
    aux_loss:
      aux_head: 0.4
    ignore_index: -9999
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # This block is commented out when running inference on full tiles.
    # Inference on full tiles is still possible with this parameter enabled, and it
    # reduces the compute requirement; however, it introduces artefacts ("patchy"
    # predictions) in the output.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 224
    #   w_crop: 224
    #   w_stride: 224
    #   average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
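
### Usage (a sketch; assumes the TerraTorch CLI, which wraps LightningCLI, and a
### hypothetical filename for this config)
# Fine-tune:
#   terratorch fit --config fine_tune_suhi.yaml
# Evaluate on the test split from a trained checkpoint:
#   terratorch test --config fine_tune_suhi.yaml --ckpt_path path/to/checkpoint.ckpt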