Carlos Gomes commited on
Commit
a82c831
1 Parent(s): 0fadd77

add config file

Browse files
Files changed (1) hide show
  1. config.yaml +142 -0
config.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.1.1
2
+ seed_everything: 42
3
+
4
+ ### Trainer configuration
5
+ trainer:
6
+ accelerator: auto
7
+ strategy: auto
8
+ devices: auto
9
+ num_nodes: 1
10
+ # precision: 16-mixed
11
+ logger:
12
+ class_path: TensorBoardLogger
13
+ init_args:
14
+ save_dir: ./experiments
15
+ name: finetune_region
16
+ callbacks:
17
+ - class_path: RichProgressBar
18
+ - class_path: LearningRateMonitor
19
+ init_args:
20
+ logging_interval: epoch
21
+ - class_path: EarlyStopping
22
+ init_args:
23
+ monitor: val/loss
24
+ patience: 100
25
+ max_epochs: 300
26
+ check_val_every_n_epoch: 1
27
+ log_every_n_steps: 20
28
+ enable_checkpointing: true
29
+ default_root_dir: ./experiments
30
+
31
+ ### Data configuration
32
+ data:
33
+ class_path: GenericNonGeoPixelwiseRegressionDataModule
34
+ init_args:
35
+ batch_size: 64
36
+ num_workers: 8
37
+ train_transform:
38
+ - class_path: albumentations.HorizontalFlip
39
+ init_args:
40
+ p: 0.5
41
+ - class_path: albumentations.Rotate
42
+ init_args:
43
+ limit: 30
44
+ border_mode: 0 # cv2.BORDER_CONSTANT
45
+ value: 0
46
+ # mask_value: 1
47
+ p: 0.5
48
+ - class_path: ToTensorV2
49
+ # Specify all bands which are in the input data.
50
+ # -1 are placeholders for bands that are in the data but that we will discard
51
+ dataset_bands:
52
+ - -1
53
+ - BLUE
54
+ - GREEN
55
+ - RED
56
+ - NIR_NARROW
57
+ - SWIR_1
58
+ - SWIR_2
59
+ - -1
60
+ - -1
61
+ - -1
62
+ - -1
63
+ output_bands: #Specify the bands which are used from the input data.
64
+ - BLUE
65
+ - GREEN
66
+ - RED
67
+ - NIR_NARROW
68
+ - SWIR_1
69
+ - SWIR_2
70
+ rgb_indices:
71
+ - 2
72
+ - 1
73
+ - 0
74
+ # Directory roots to training, validation and test datasplits:
75
+ train_data_root: train_images
76
+ train_label_data_root: train_labels
77
+ val_data_root: val_images
78
+ val_label_data_root: val_labels
79
+ test_data_root: test_images
80
+ test_label_data_root: test_labels
81
+ means: # Mean value of the training dataset per band
82
+ - 547.36707
83
+ - 898.5121
84
+ - 1020.9082
85
+ - 2665.5352
86
+ - 2340.584
87
+ - 1610.1407
88
+ stds: # Standard deviation of the training dataset per band
89
+ - 411.4701
90
+ - 558.54065
91
+ - 815.94025
92
+ - 812.4403
93
+ - 1113.7145
94
+ - 1067.641
95
+ # Nodata value in label data
96
+ no_label_replace: -1
97
+ # Nodata value in the input data
98
+ no_data_replace: 0
99
+
100
+ ### Model configuration
101
+ model:
102
+ class_path: terratorch.tasks.PixelwiseRegressionTask
103
+ init_args:
104
+ model_args:
105
+ decoder: UperNetDecoder
106
+ pretrained: false
107
+ backbone: prithvi_swin_B
108
+ backbone_drop_path_rate: 0.3
109
+ decoder_channels: 32
110
+ in_channels: 6
111
+ bands:
112
+ - BLUE
113
+ - GREEN
114
+ - RED
115
+ - NIR_NARROW
116
+ - SWIR_1
117
+ - SWIR_2
118
+ num_frames: 1
119
+ head_dropout: 0.16194593880230534
120
+ head_final_act: torch.nn.ReLU
121
+ head_learned_upscale_layers: 2
122
+ loss: rmse
123
+ ignore_index: -1
124
+ freeze_backbone: false
125
+ freeze_decoder: false
126
+ model_factory: PrithviModelFactory
127
+ # uncomment this block for tiled inference
128
+ # tiled_inference_parameters:
129
+ # h_crop: 224
130
+ # h_stride: 192
131
+ # w_crop: 224
132
+ # w_stride: 192
133
+ # average_patches: true
134
+ optimizer:
135
+ class_path: torch.optim.AdamW
136
+ init_args:
137
+ lr: 0.00031406904191973693
138
+ weight_decay: 0.03283253068408954
139
+ lr_scheduler:
140
+ class_path: ReduceLROnPlateau
141
+ init_args:
142
+ monitor: val/loss