import torch import torch.nn as nn import pytorch_lightning as pl from torchmetrics import classification import wandb from matplotlib import pyplot as plt import numpy as np import matplotlib.ticker as ticker from matplotlib.colors import ListedColormap from huggingface_hub import PyTorchModelHubMixin from lion_pytorch import Lion import json from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck def safe_shape(x): if isinstance(x, tuple): # loop through tuple shape_info = '(tuple) : ' for i in x: shape_info += str(i.shape) + ', ' return shape_info if isinstance(x, list): # loop through list shape_info = '(list) : ' for i in x: shape_info += str(i.shape) + ', ' return shape_info return x.shape class ConvModule(nn.Module): """ A simple convolutional module including Conv, BatchNorm, and ReLU layers. """ def __init__(self, in_channels, out_channels, kernel_size, padding, dilation): super(ConvModule, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn(x) return self.relu(x) class HierarchicalFCNHead(nn.Module): """ Hierarchical FCN Head for semantic segmentation. """ def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False): super(HierarchicalFCNHead, self).__init__() self.debug = debug self.convs = nn.Sequential(*[ ConvModule( in_channels if i == 0 else out_channels, out_channels, kernel_size, padding=dilation * (kernel_size // 2), dilation=dilation ) for i in range(num_convs) ]) self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1) self.dropout = nn.Dropout2d(p=dropout_p) def forward(self, x): if self.debug: print('HierarchicalFCNHead forward INP: ', safe_shape(x)) x = self.convs(x) features = self.dropout(x) output = self.conv_seg(features) if self.debug: print('HierarchicalFCNHead forward features OUT: ', safe_shape(features)) print('HierarchicalFCNHead forward output OUT: ', safe_shape(output)) return output, features class LabelRefinementHead(nn.Module): """ Similar to the label refinement module introduced in the ZueriCrop paper, this module refines the predictions for tier 3. It takes the raw predictions from head 1, head 2 and head 3 and refines them to produce the final prediction for tier 3. According to ZueriCrop, this helps with making the predictions more consistent across the different tiers. """ def __init__(self, input_channels, num_classes): super(LabelRefinementHead, self).__init__() self.cnn_layers = nn.Sequential( # 1x1 Convolutional layer nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0), nn.BatchNorm2d(128), nn.ReLU(inplace=True), # 3x3 Convolutional layer nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Dropout(p=0.5), # Skip connection (implemented in forward method) # Another 3x3 Convolutional layer nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), # 1x1 Convolutional layer to adjust the number of output channels to num_classes nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0), nn.Dropout(p=0.5) ) def forward(self, x): # Apply initial conv layer y = self.cnn_layers[0:3](x) # Save for skip connection y_skip = y # Apply the next two conv layers y = self.cnn_layers[3:9](y) # Skip connection (element-wise addition) y = y + y_skip # Apply the last conv layer y = self.cnn_layers[9:](y) return y class HierarchicalClassifier(nn.Module): def __init__( self, heads_spec, dropout_p=0.1, img_size=256, patch_size=16, num_frames=3, bands=[0, 1, 2, 3, 4, 5], backbone_weights_path=None, freeze_backbone=True, use_bottleneck_neck=False, bottleneck_reduction_factor=4, loss_ignore_background=False, debug=False ): super(HierarchicalClassifier, self).__init__() self.embed_dim = 768 if num_frames % 3 != 0: raise ValueError("The number of frames must be a multiple of 3, it is currently: ", num_frames) self.num_frames = num_frames self.hp, self.wp = img_size // patch_size, img_size // patch_size self.heads_spec = heads_spec self.dropout_p = dropout_p self.loss_ignore_background = loss_ignore_background self.debug = debug if self.debug: print('hp and wp: ', self.hp, self.wp) self.prithvi = TemporalViTEncoder( img_size=img_size, patch_size=patch_size, num_frames=3, tubelet_size=1, in_chans=len(bands), embed_dim=self.embed_dim, depth=12, num_heads=8, mlp_ratio=4.0, norm_pix_loss=False, pretrained=backbone_weights_path, debug=self.debug ) # (Un)freeze the backbone for param in self.prithvi.parameters(): param.requires_grad = not freeze_backbone # Neck to transform the token-based output of the transformer into a spatial feature map number_of_necks = self.num_frames // 3 if use_bottleneck_neck: self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck( embed_dim=self.embed_dim * 3, output_embed_dim=self.embed_dim * 3, drop_cls_token=True, Hp=self.hp, Wp=self.wp, bottleneck_reduction_factor=bottleneck_reduction_factor ) for _ in range(number_of_necks)]) else: self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck( embed_dim=self.embed_dim * 3, output_embed_dim=self.embed_dim * 3, drop_cls_token=True, Hp=self.hp, Wp=self.wp, ) for _ in range(number_of_necks)]) # Initialize heads and loss weights based on tiers self.heads = nn.ModuleDict() self.loss_weights = {} self.total_classes = 0 # Build HierarchicalFCNHeads head_count = 0 for head_name, head_info in self.heads_spec.items(): head_type = head_info['type'] num_classes = head_info['num_classes_to_predict'] loss_weight = head_info['loss_weight'] if head_type == 'HierarchicalFCNHead': num_classes = head_info['num_classes_to_predict'] loss_weight = head_info['loss_weight'] kernel_size = head_info.get('kernel_size', 3) num_convs = head_info.get('num_convs', 1) num_channels = head_info.get('num_channels', 256) self.total_classes += num_classes self.heads[head_name] = HierarchicalFCNHead( in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels, out_channels=num_channels, num_classes=num_classes, num_convs=num_convs, kernel_size=kernel_size, dropout_p=self.dropout_p, debug=self.debug ) self.loss_weights[head_name] = loss_weight # NOTE: LabelRefinementHead must be the last in the dict, otherwise the total_classes will be incorrect if head_type == 'LabelRefinementHead': self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes) self.refinement_head_name = head_name self.loss_weights[head_name] = loss_weight head_count += 1 self.loss_func = nn.CrossEntropyLoss(ignore_index=-1) def forward(self, x): if self.debug: print(f"Input shape: {safe_shape(x)}") # torch.Size([4, 6, 9, 224, 224]) # Extract features from the base model if len(self.necks) == 1: features = [x] else: features = torch.chunk(x, len(self.necks), dim=2) features = [self.prithvi(x) for x in features] if self.debug: print(f"Features shape after base model: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 589, 768]), , (tuple) : torch.Size # Process through the neck features = [neck(feat_) for feat_, neck in zip(features, self.necks)] if self.debug: print(f"Features shape after neck: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 2304, 224, 224]), , (tuple) : torch.Size # Remove from tuple features = [feat[0] for feat in features] # stack the features to create a tensor of torch.Size([4, 6912, 224, 224]) features = torch.concatenate(features, dim=1) if self.debug: print(f"Features shape after removing tuple: {safe_shape(features)}") # torch.Size([4, 6912, 224, 224]) # Process through the heads outputs = {} for tier_name, head in self.heads.items(): output, features = head(features) outputs[tier_name] = output if self.debug: print(f"Features shape after {tier_name} head: {safe_shape(features)}") print(f"Output shape after {tier_name} head: {safe_shape(output)}") # Process through the classification refinement head output_concatenated = torch.cat(list(outputs.values()), dim=1) output_refinement_head = self.refinement_head(output_concatenated) outputs[self.refinement_head_name] = output_refinement_head return outputs def calculate_loss(self, outputs, targets): total_loss = 0 loss_per_head = {} for head_name, output in outputs.items(): if self.debug: print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}") target = targets[self.heads_spec[head_name]['target_idx']] loss_target = target if self.loss_ignore_background: loss_target = target.clone() # Clone as original target needed in backward pass loss_target[loss_target == 0] = -1 # Set background class to ignore_index -1 for loss calculation loss = self.loss_func(output, loss_target) loss_per_head[f'{head_name}'] = loss total_loss += loss * self.loss_weights[head_name] return total_loss, loss_per_head class Messis(pl.LightningModule, PyTorchModelHubMixin): def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) self.model = HierarchicalClassifier( heads_spec=hparams['heads_spec'], dropout_p=hparams.get('dropout_p'), img_size=hparams.get('img_size'), patch_size=hparams.get('patch_size'), num_frames=hparams.get('num_frames'), bands=hparams.get('bands'), backbone_weights_path=hparams.get('backbone_weights_path'), freeze_backbone=hparams['freeze_backbone'], use_bottleneck_neck=hparams.get('use_bottleneck_neck'), bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor'), loss_ignore_background=hparams.get('loss_ignore_background'), debug=hparams.get('debug') ) def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): return self.__step(batch, batch_idx, "train") def validation_step(self, batch, batch_idx): return self.__step(batch, batch_idx, "val") def test_step(self, batch, batch_idx): return self.__step(batch, batch_idx, "test") def configure_optimizers(self): # select case on optimizer match self.hparams.get('optimizer', 'Adam'): case 'Adam': optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3)) case 'AdamW': optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01)) case 'SGD': optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9)) case 'Lion': # https://github.com/lucidrains/lion-pytorch | Typically lr 3-10 times lower than Adam and weight_decay 3-10 times higher optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1)) case _: raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported") return optimizer def __step(self, batch, batch_idx, stage): inputs, targets = batch targets = torch.stack(targets[0]) outputs = self(inputs) loss, loss_per_head = self.model.calculate_loss(outputs, targets) loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head} loss_proportions = { f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head} loss_detail_dict = {**loss_per_head_named, **loss_proportions} if self.hparams.get('debug'): print(f"Step Inputs shape: {safe_shape(inputs)}") print(f"Step Targets shape: {safe_shape(targets)}") print(f"Step Outputs dict keys: {outputs.keys()}") # NOTE: All metrics other than loss are tracked by callbacks (LogMessisMetrics) self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True) return {'loss': loss, 'outputs': outputs} class LogConfusionMatrix(pl.Callback): def __init__(self, hparams, dataset_info_file, debug=False): super().__init__() assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams" self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)} self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None) self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None) assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True" assert self.final_head_name is not None, "No head found with 'is_final_head' set to True" self.tiers = list(self.tiers_dict.keys()) self.phases = ['train', 'val', 'test'] self.modes = ['pixelwise', 'majority'] self.debug = debug if debug: print(f"Final head identified as: {self.final_head_name}") print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}") with open(dataset_info_file, 'r') as f: self.dataset_info = json.load(f) # Initialize confusion matrices self.metrics_to_compute = ['confusion_matrix'] self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases} def __init_metrics(self, tier, phase): num_classes = self.tiers_dict[tier]['num_classes_to_predict'] confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes) return { 'confusion_matrix': confusion_matrix } def setup(self, trainer, pl_module, stage=None): # Move all metrics to the correct device at the start of the training/validation device = pl_module.device for phase_metrics in self.metrics.values(): for tier_metrics in phase_metrics.values(): for mode_metrics in tier_metrics.values(): for metric in self.metrics_to_compute: mode_metrics[metric].to(device) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train') def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val') def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test') def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase): if trainer.sanity_checking: return targets = torch.stack(batch[1][0]) # (tiers, batch, H, W) outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W) field_ids = batch[1][1].permute(1, 0, 2, 3)[0] pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info) for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes): # Update all metrics assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}" assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}" for pred, target, tier in zip(preds, targets, self.tiers): if self.debug: print(f"Updating confusion matrix for {phase} {tier} {mode}") metrics = self.metrics[phase][tier][mode] # flatten and remove background class if the mode is majority (such that the background class is not included in the confusion matrix) if mode == 'majority': pred = pred[target != 0] target = target[target != 0] metrics['confusion_matrix'].update(pred, target) @staticmethod def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info): """ Get the pixelwise and majority predictions from the model outputs. The pixelwise tier predictions are derived from the refinement_head_outputs predictions. The majority last tier predictions are derived from the refinement_head_outputs. And then the majority lower-tier predictions are derived from the majority highest-tier predictions. Also sets the background to 0 for all field majority predictions (regardless of what the model predicts for the background class). As this is a classification task and not a segmentation task and the field boundaries are known beforehand and not of any interest. Args: refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier. tiers (list of str): List of tiers e.g. ['tier1', 'tier2', 'tier3']. field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction. dataset_info (dict): The dataset information. Returns: torch.Tensor(tiers, batch, H, W): The pixelwise predictions. torch.Tensor(tiers, batch, H, W): The majority predictions. """ # Assuming the highest tier is the last one in the list highest_tier = tiers[-1] pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1) # (batch, H, W) majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids) tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier} pixelwise_outputs = {highest_tier: pixelwise_highest_tier} majority_outputs = {highest_tier: majority_highest_tier} # Initialize pixelwise and majority outputs for each tier for tier in tiers: if tier != highest_tier: pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier) majority_outputs[tier] = torch.zeros_like(majority_highest_tier) # Map the highest tier to lower tiers for i, mappings in enumerate(zip(*tier_mapping.values())): for j, tier in enumerate(tier_mapping.keys()): pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j] majority_outputs[tier][majority_highest_tier == i] = mappings[j] pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers]) majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers]) # Ensure these are tensors assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor" assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor" return pixelwise_outputs_stacked, majority_outputs_stacked @staticmethod def get_field_majority_preds(output, field_ids): """ Get the majority prediction for each field in the batch. The majority excludes the background class. Args: output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined) field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction. Returns: torch.Tensor(batch, H, W): The majority predictions. """ # remove the background class pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1 # (batch, H, W) majority_preds = torch.zeros_like(pixelwise) for batch in range(len(pixelwise)): field_ids_batch = field_ids[batch] for field_id in np.unique(field_ids_batch.cpu().numpy()): if field_id == 0: continue field_mask = field_ids_batch == field_id flattened_pred = pixelwise[batch][field_mask].view(-1) # Flatten the prediction flattened_pred = flattened_pred[flattened_pred != 0] # Exclude background class if len(flattened_pred) == 0: continue mode_pred, _ = torch.mode(flattened_pred) # Compute mode prediction majority_preds[batch][field_mask] = mode_pred.item() return majority_preds def on_train_epoch_end(self, trainer, pl_module): # Log and then reset the confusion matrices after training epoch self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train') def on_validation_epoch_end(self, trainer, pl_module): # Log and then reset the confusion matrices after validation epoch self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val') def on_test_epoch_end(self, trainer, pl_module): # Log and then reset the confusion matrices after test epoch self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test') def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase): if trainer.sanity_checking: return for tier in self.tiers: for mode in self.modes: metrics = self.metrics[phase][tier][mode] confusion_matrix = metrics['confusion_matrix'] if self.debug: print(f"Logging and resetting confusion matrix for {phase} {tier} Update count: {confusion_matrix._update_count}") matrix = confusion_matrix.compute() # columns are predictions and rows are targets # Calculate percentages matrix = matrix.float() row_sums = matrix.sum(dim=1, keepdim=True) matrix_percent = matrix / row_sums # Ensure percentages sum to 1 for each row or handle NaNs row_sum_check = matrix_percent.sum(dim=1) valid_rows = ~torch.isnan(row_sum_check) if valid_rows.any(): assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows" # Sort the matrix and labels by the total number of instances sorted_indices = row_sums.squeeze().argsort(descending=True) matrix_percent = matrix_percent[sorted_indices, :] # sort rows matrix_percent = matrix_percent[:, sorted_indices] # sort columns class_labels = [self.dataset_info[tier][i] for i in sorted_indices] row_sums_sorted = row_sums[sorted_indices] # Check for zero rows after sorting zero_rows = (row_sums_sorted == 0).squeeze() fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140) ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis') ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1))) ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1))) ax.set_xticklabels(class_labels + [''], rotation=45) ax.set_yticklabels(class_labels + ['']) # Add total number of instances to the y-axis labels y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))] ax.set_yticklabels(y_labels + ['']) ax.set_xlabel('Predictions') ax.set_ylabel('Targets') # Move x-axis label and ticks to the top ax.xaxis.set_label_position('top') ax.xaxis.set_ticks_position('top') fig.tight_layout() for i in range(matrix.size(0)): for j in range(matrix.size(1)): if zero_rows[i]: ax.text(j, i, 'N/A', ha='center', va='center', color='black') else: ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold') # coral red trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)}) plt.close() confusion_matrix.reset() class LogMessisMetrics(pl.Callback): def __init__(self, hparams, dataset_info_file, debug=False): super().__init__() assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams" self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)} self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None) self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None) assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True" assert self.final_head_name is not None, "No head found with 'is_final_head' set to True" self.tiers = list(self.tiers_dict.keys()) self.phases = ['train', 'val', 'test'] self.modes = ['pixelwise', 'majority'] self.debug = debug if debug: print(f"Last tier identified as: {self.last_tier_name}") print(f"Final head identified as: {self.final_head_name}") print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}") with open(dataset_info_file, 'r') as f: self.dataset_info = json.load(f) # Initialize metrics self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall' ,'f1', 'weighted_f1', 'cohen_kappa'] self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases} self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases} self.images_to_log_targets = {phase: None for phase in self.phases} self.field_ids_to_log_targets = {phase: None for phase in self.phases} self.inputs_to_log = {phase: None for phase in self.phases} def __init_metrics(self, tier, phase): num_classes = self.tiers_dict[tier]['num_classes_to_predict'] accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro') weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted') per_class_accuracies = { class_index: classification.BinaryAccuracy() for class_index in range(num_classes) } precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro') weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted') recall = classification.MulticlassRecall(num_classes=num_classes, average='macro') weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted') f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro') weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted') cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes) return { 'accuracy': accuracy, 'weighted_accuracy': weighted_accuracy, 'per_class_accuracies': per_class_accuracies, 'precision': precision, 'weighted_precision': weighted_precision, 'recall': recall, 'weighted_recall': weighted_recall, 'f1': f1, 'weighted_f1': weighted_f1, 'cohen_kappa': cohen_kappa } def setup(self, trainer, pl_module, stage=None): # Move all metrics to the correct device at the start of the training/validation device = pl_module.device for phase_metrics in self.metrics.values(): for tier_metrics in phase_metrics.values(): for mode_metrics in tier_metrics.values(): for metric in self.metrics_to_compute: mode_metrics[metric].to(device) for class_accuracy in mode_metrics['per_class_accuracies'].values(): class_accuracy.to(device) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train') def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val') def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test') def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase): if trainer.sanity_checking: return if self.debug: print(f"{phase} batch ended. Updating metrics...") targets = torch.stack(batch[1][0]) # (tiers, batch, H, W) outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W) field_ids = batch[1][1].permute(1, 0, 2, 3)[0] pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info) for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes): # Update all metrics assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}" assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}" self.images_to_log[phase][mode] = preds[-1] for pred, target, tier in zip(preds, targets, self.tiers): # flatten and remove background class if the mode is majority (such that the background class is not considered in the metrics) if mode == 'majority': pred = pred[target != 0] target = target[target != 0] metrics = self.metrics[phase][tier][mode] for metric in self.metrics_to_compute: metrics[metric].update(pred, target) if self.debug: print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}") self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies']) self.images_to_log_targets[phase] = targets[-1] self.field_ids_to_log_targets[phase] = field_ids self.inputs_to_log[phase] = batch[0] def __update_per_class_metrics(self, preds, targets, per_class_accuracies): for class_index, class_accuracy in per_class_accuracies.items(): if not (targets == class_index).any(): continue if class_index == 0: # Mask out non-background elements for background class (0) class_mask = targets != 0 else: # Mask out background elements for other classes class_mask = targets == 0 preds_fields = preds[~class_mask] targets_fields = targets[~class_mask] # Prepare for binary classification (needs to be float) preds_class = (preds_fields == class_index).float() targets_class = (targets_fields == class_index).float() class_accuracy.update(preds_class, targets_class) if self.debug: print(f"Shape of preds_fields: {preds_fields.shape}") print(f"Shape of targets_fields: {targets_fields.shape}") print(f"Unique values in preds_fields: {torch.unique(preds_fields)}") print(f"Unique values in targets_fields: {torch.unique(targets_fields)}") print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}") def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): self.__on_epoch_end(trainer, pl_module, 'train') def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): self.__on_epoch_end(trainer, pl_module, 'val') def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): self.__on_epoch_end(trainer, pl_module, 'test') def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase): if trainer.sanity_checking: return # Skip during sanity check (avoid warning about metric compute being called before update) for tier in self.tiers: for mode in self.modes: metrics = self.metrics[phase][tier][mode] # Calculate and reset in tier: Accuracy, WeightedAccuracy, Precision, Recall, F1, Cohen's Kappa metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute} pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True) for metric in self.metrics_to_compute: metrics[metric].reset() # Per-class metrics # NOTE: Some literature reports "per class accuracy" but what they actually mean is "per class recall". # Using the accuracy formula per class has no value in our imbalanced multi-class setting (TN's inflate scores!) # We calculate all 4 metrics. This allows us to calculate any macro/micro score later if needed. class_metrics = [] class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier] for class_index, class_accuracy in metrics['per_class_accuracies'].items(): if class_accuracy._update_count == 0: continue # Skip if no updates have been made tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0 precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0 f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0 n_of_class = (tp + fn).item() class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class]) class_accuracy.reset() wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"]) trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table}) # use the same n_classes for all images, such that they are comparable n_classes = max([ torch.max(self.images_to_log_targets[phase]), torch.max(self.images_to_log[phase]["majority"]), torch.max(self.images_to_log[phase]["pixelwise"]) ]) images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes] images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase])) images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes)) images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu())) examples = [] for i in range(len(images[0])): example = np.concatenate([img[i] for img in images], axis=0) examples.append(wandb.Image(example, caption=f"From Top to Bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields")) trainer.logger.experiment.log({f"{phase}_examples": examples}) # Log segmentation masks batch_input_data = self.inputs_to_log[phase].cpu() # shape [BS, 6, N_TIMESTEPS, 224, 224] ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy() pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy() field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy() correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks) class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])} segmentation_masks = [] for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks): middle_timestep_index = input_data.shape[1] // 2 # Get the middle timestamp index gamma = 2.5 # Gamma for brightness adjustment rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy() # Shape [224, 224, 3] rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min()) rgb_image = np.power(rgb_image, 1.0 / gamma) rgb_image = (rgb_image * 255).astype(np.uint8) mask_img = wandb.Image( rgb_image, masks={ "predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels}, "predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels}, "ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels}, "correctness": {"mask_data": correctness_mask, "class_labels": { 0: "Background", 1: "Wrong", 2: "Right" }}, }, ) segmentation_masks.append(mask_img) trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks}) if self.debug: print(f"{phase} epoch ended. Logging & resetting metrics...", trainer.sanity_checking) @staticmethod def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks): """ Create a tensor that shows the positive and negative classifications of the model. Args: field_majority_masks (np.ndarray): The field majority masks generated by the model. ground_truth_masks (np.ndarray): The ground truth masks. Returns: np.ndarray: An array with values: - 0 where the target is 0, - 2 where the prediction matches the target, - 1 where the prediction does not match the target. """ correctness_mask = np.zeros_like(ground_truth_masks, dtype=int) matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0) correctness_mask[matches] = 2 mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0) correctness_mask[mismatches] = 1 return correctness_mask @staticmethod def create_positive_negative_image(generated_images, target_images): """ Create an image that shows the positive and negative classifications of the model. Args: generated_images (torch.Tensor): The images generated by the model. target_images (torch.Tensor): The target images. Returns: list: A list of processed images. """ classification_masks = generated_images == target_images processed_imgs = [] for mask, target in zip(classification_masks, target_images): # color the background white, right classifications green, wrong classifications red colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8) mask = mask.bool() # Convert to boolean tensor colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8) colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8) colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8) processed_imgs.append(colored_img.cpu()) return processed_imgs @staticmethod def process_images(imgs, max=None): """ Process a batch of images to be logged on wandb. Args: imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed. max (float, optional): The maximum value to normalize the images. Defaults to None. If None, the maximum value in the batch is used. """ if max is None: max = np.max(imgs.cpu().numpy()) normalized_img = imgs / max processed_imgs = [] for img in normalized_img.cpu().numpy(): if max < 60: cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors) else: cmap = plt.get_cmap('viridis') colored_img = cmap(img) colored_img[img == 0] = [0, 0, 0, 1] colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8) processed_imgs.append(colored_img_uint8) return processed_imgs