import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import classification
import wandb
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
from matplotlib.colors import ListedColormap
from huggingface_hub import PyTorchModelHubMixin
from lion_pytorch import Lion
import json

from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck

def safe_shape(x):
    """Return a printable shape description for a tensor, or for a tuple/list of tensors."""
    if isinstance(x, (tuple, list)):
        kind = 'tuple' if isinstance(x, tuple) else 'list'
        return f'({kind}) : ' + ', '.join(str(i.shape) for i in x)
    return str(x.shape)
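
# Illustrative usage (not part of the pipeline): safe_shape handles tensors and
# containers of tensors uniformly, e.g.
#   safe_shape(torch.zeros(2, 3))                    -> 'torch.Size([2, 3])'
#   safe_shape((torch.zeros(2, 3), torch.zeros(4)))  -> '(tuple) : torch.Size([2, 3]), torch.Size([4])'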

class ConvModule(nn.Module):
    """
    A simple convolutional module including Conv, BatchNorm, and ReLU layers.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)
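
# Usage sketch (shapes are illustrative): with stride 1 and padding = dilation * (kernel_size // 2),
# a ConvModule preserves the spatial resolution:
#   m = ConvModule(in_channels=16, out_channels=32, kernel_size=3, padding=1, dilation=1)
#   m(torch.randn(2, 16, 64, 64)).shape  # -> torch.Size([2, 32, 64, 64])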

class HierarchicalFCNHead(nn.Module):
    """
    Hierarchical FCN head for semantic segmentation.
    """
    def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False):
        super(HierarchicalFCNHead, self).__init__()
        self.debug = debug
        self.convs = nn.Sequential(*[
            ConvModule(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size,
                padding=dilation * (kernel_size // 2),
                dilation=dilation
            ) for i in range(num_convs)
        ])
        self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1)
        self.dropout = nn.Dropout2d(p=dropout_p)

    def forward(self, x):
        if self.debug:
            print('HierarchicalFCNHead forward INP: ', safe_shape(x))
        x = self.convs(x)
        features = self.dropout(x)
        output = self.conv_seg(features)
        if self.debug:
            print('HierarchicalFCNHead forward features OUT: ', safe_shape(features))
            print('HierarchicalFCNHead forward output OUT: ', safe_shape(output))
        return output, features
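
# Usage sketch (channel/class counts are illustrative): each head returns (logits, features);
# HierarchicalClassifier chains the heads by feeding one head's features into the next:
#   head1 = HierarchicalFCNHead(in_channels=2304, out_channels=256, num_classes=4)
#   head2 = HierarchicalFCNHead(in_channels=256, out_channels=256, num_classes=12)
#   logits1, feats = head1(torch.randn(2, 2304, 224, 224))  # logits1: (2, 4, 224, 224)
#   logits2, _ = head2(feats)                                # logits2: (2, 12, 224, 224)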

class LabelRefinementHead(nn.Module):
    """
    Label refinement module similar to the one introduced in the ZueriCrop paper. It takes the raw
    predictions from heads 1, 2, and 3 and refines them into the final prediction for tier 3.
    According to ZueriCrop, this helps make the predictions more consistent across the tiers.
    """
    def __init__(self, input_channels, num_classes):
        super(LabelRefinementHead, self).__init__()
        self.cnn_layers = nn.Sequential(
            # 1x1 convolutional block
            nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 3x3 convolutional block
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            # Skip connection (implemented in the forward method)
            # Another 3x3 convolutional block
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 1x1 convolutional layer to adjust the number of output channels to num_classes
            nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.Dropout(p=0.5)
        )

    def forward(self, x):
        # First 1x1 conv block (indices 0-2: Conv, BN, ReLU)
        y = self.cnn_layers[0:3](x)
        # Save for the skip connection
        y_skip = y
        # Both 3x3 conv blocks up to (but not including) the second ReLU (indices 3-8)
        y = self.cnn_layers[3:9](y)
        # Skip connection (element-wise addition)
        y = y + y_skip
        # Final ReLU after the skip, the 1x1 conv, and dropout (indices 9-11)
        y = self.cnn_layers[9:](y)
        return y
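
# Usage sketch (class counts are illustrative): the refinement head consumes the concatenated
# raw logits of all tiers and emits refined logits for the last tier:
#   refine = LabelRefinementHead(input_channels=4 + 12 + 48, num_classes=48)
#   refined = refine(torch.cat([logits1, logits2, logits3], dim=1))  # -> (B, 48, H, W)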

class HierarchicalClassifier(nn.Module):
    def __init__(
            self,
            heads_spec,
            dropout_p=0.1,
            img_size=256,
            patch_size=16,
            num_frames=3,
            bands=[0, 1, 2, 3, 4, 5],
            backbone_weights_path=None,
            freeze_backbone=True,
            use_bottleneck_neck=False,
            bottleneck_reduction_factor=4,
            loss_ignore_background=False,
            debug=False
        ):
        super(HierarchicalClassifier, self).__init__()
        self.embed_dim = 768
        if num_frames % 3 != 0:
            raise ValueError(f"The number of frames must be a multiple of 3, it is currently: {num_frames}")
        self.num_frames = num_frames
        self.hp, self.wp = img_size // patch_size, img_size // patch_size
        self.heads_spec = heads_spec
        self.dropout_p = dropout_p
        self.loss_ignore_background = loss_ignore_background
        self.debug = debug
        if self.debug:
            print('hp and wp: ', self.hp, self.wp)
        self.prithvi = TemporalViTEncoder(
            img_size=img_size,
            patch_size=patch_size,
            num_frames=3,
            tubelet_size=1,
            in_chans=len(bands),
            embed_dim=self.embed_dim,
            depth=12,
            num_heads=8,
            mlp_ratio=4.0,
            norm_pix_loss=False,
            pretrained=backbone_weights_path,
            debug=self.debug
        )

        # (Un)freeze the backbone
        for param in self.prithvi.parameters():
            param.requires_grad = not freeze_backbone

        # Necks to transform the token-based output of the transformer into a spatial feature map (one neck per 3 frames)
        number_of_necks = self.num_frames // 3
        if use_bottleneck_neck:
            self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck(
                embed_dim=self.embed_dim * 3,
                output_embed_dim=self.embed_dim * 3,
                drop_cls_token=True,
                Hp=self.hp,
                Wp=self.wp,
                bottleneck_reduction_factor=bottleneck_reduction_factor
            ) for _ in range(number_of_necks)])
        else:
            self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck(
                embed_dim=self.embed_dim * 3,
                output_embed_dim=self.embed_dim * 3,
                drop_cls_token=True,
                Hp=self.hp,
                Wp=self.wp,
            ) for _ in range(number_of_necks)])

        # Initialize heads and loss weights based on tiers
        self.heads = nn.ModuleDict()
        self.loss_weights = {}
        self.total_classes = 0

        # Build the heads
        head_count = 0
        for head_name, head_info in self.heads_spec.items():
            head_type = head_info['type']
            num_classes = head_info['num_classes_to_predict']
            loss_weight = head_info['loss_weight']
            if head_type == 'HierarchicalFCNHead':
                kernel_size = head_info.get('kernel_size', 3)
                num_convs = head_info.get('num_convs', 1)
                num_channels = head_info.get('num_channels', 256)
                self.total_classes += num_classes
                self.heads[head_name] = HierarchicalFCNHead(
                    in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels,
                    out_channels=num_channels,
                    num_classes=num_classes,
                    num_convs=num_convs,
                    kernel_size=kernel_size,
                    dropout_p=self.dropout_p,
                    debug=self.debug
                )
                self.loss_weights[head_name] = loss_weight
            # NOTE: The LabelRefinementHead must be the last entry in the dict, otherwise total_classes will be incorrect
            if head_type == 'LabelRefinementHead':
                self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes)
                self.refinement_head_name = head_name
                self.loss_weights[head_name] = loss_weight
            head_count += 1

        self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, x):
        if self.debug:
            print(f"Input shape: {safe_shape(x)}")  # e.g. torch.Size([4, 6, 9, 224, 224])

        # Extract features with the backbone, one chunk of 3 frames per neck
        if len(self.necks) == 1:
            features = [x]
        else:
            features = torch.chunk(x, len(self.necks), dim=2)
        features = [self.prithvi(chunk) for chunk in features]
        if self.debug:
            print(f"Features shape after base model: {', '.join([safe_shape(f) for f in features])}")  # e.g. (tuple) : torch.Size([4, 589, 768]), ...

        # Process through the necks
        features = [neck(feat_) for feat_, neck in zip(features, self.necks)]
        if self.debug:
            print(f"Features shape after neck: {', '.join([safe_shape(f) for f in features])}")  # e.g. (tuple) : torch.Size([4, 2304, 224, 224]), ...

        # Each neck returns a tuple; keep the feature map and concatenate along the channel dimension
        features = [feat[0] for feat in features]
        features = torch.cat(features, dim=1)
        if self.debug:
            print(f"Features shape after removing tuple: {safe_shape(features)}")  # e.g. torch.Size([4, 6912, 224, 224])

        # Process through the heads
        outputs = {}
        for tier_name, head in self.heads.items():
            output, features = head(features)
            outputs[tier_name] = output
            if self.debug:
                print(f"Features shape after {tier_name} head: {safe_shape(features)}")
                print(f"Output shape after {tier_name} head: {safe_shape(output)}")

        # Process through the label refinement head
        output_concatenated = torch.cat(list(outputs.values()), dim=1)
        output_refinement_head = self.refinement_head(output_concatenated)
        outputs[self.refinement_head_name] = output_refinement_head
        return outputs

    def calculate_loss(self, outputs, targets):
        total_loss = 0
        loss_per_head = {}
        for head_name, output in outputs.items():
            if self.debug:
                print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}")
            target = targets[self.heads_spec[head_name]['target_idx']]
            loss_target = target
            if self.loss_ignore_background:
                loss_target = target.clone()  # Clone, as the original target is needed in the backward pass
                loss_target[loss_target == 0] = -1  # Set the background class to ignore_index -1 for the loss calculation
            loss = self.loss_func(output, loss_target)
            loss_per_head[f'{head_name}'] = loss
            total_loss += loss * self.loss_weights[head_name]
        return total_loss, loss_per_head
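
# Illustrative heads_spec (class counts, names, and flag placement are assumptions; the keys
# mirror what HierarchicalClassifier and the logging callbacks read). The LabelRefinementHead
# entry must come last so total_classes is complete when it is built:
#   heads_spec = {
#       'tier1': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 4,
#                 'loss_weight': 0.3, 'target_idx': 0, 'is_metrics_tier': True},
#       'tier2': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 12,
#                 'loss_weight': 0.3, 'target_idx': 1, 'is_metrics_tier': True},
#       'tier3': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 48,
#                 'loss_weight': 0.4, 'target_idx': 2, 'is_last_tier': True},
#       'tier3_refined': {'type': 'LabelRefinementHead', 'num_classes_to_predict': 48,
#                         'loss_weight': 1.0, 'target_idx': 2,
#                         'is_metrics_tier': True, 'is_final_head': True},
#   }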

class Messis(pl.LightningModule, PyTorchModelHubMixin):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        # Defaults mirror the HierarchicalClassifier signature, so missing hparams do not override them with None
        self.model = HierarchicalClassifier(
            heads_spec=hparams['heads_spec'],
            dropout_p=hparams.get('dropout_p', 0.1),
            img_size=hparams.get('img_size', 256),
            patch_size=hparams.get('patch_size', 16),
            num_frames=hparams.get('num_frames', 3),
            bands=hparams.get('bands', [0, 1, 2, 3, 4, 5]),
            backbone_weights_path=hparams.get('backbone_weights_path'),
            freeze_backbone=hparams['freeze_backbone'],
            use_bottleneck_neck=hparams.get('use_bottleneck_neck', False),
            bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor', 4),
            loss_ignore_background=hparams.get('loss_ignore_background', False),
            debug=hparams.get('debug', False)
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "test")

    def configure_optimizers(self):
        # Select the optimizer based on the hparams
        match self.hparams.get('optimizer', 'Adam'):
            case 'Adam':
                optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3))
            case 'AdamW':
                optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01))
            case 'SGD':
                optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9))
            case 'Lion':
                # https://github.com/lucidrains/lion-pytorch | Typically lr 3-10x lower and weight_decay 3-10x higher than Adam
                optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1))
            case _:
                raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported")
        return optimizer

    def __step(self, batch, batch_idx, stage):
        inputs, targets = batch
        targets = torch.stack(targets[0])
        outputs = self(inputs)
        loss, loss_per_head = self.model.calculate_loss(outputs, targets)
        loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head}
        loss_proportions = {f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head}
        loss_detail_dict = {**loss_per_head_named, **loss_proportions}
        if self.hparams.get('debug'):
            print(f"Step Inputs shape: {safe_shape(inputs)}")
            print(f"Step Targets shape: {safe_shape(targets)}")
            print(f"Step Outputs dict keys: {outputs.keys()}")
        # NOTE: All metrics other than the loss are tracked by callbacks (LogMessisMetrics)
        self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss, 'outputs': outputs}
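
# Minimal construction sketch (hyperparameter values are assumptions, not a reference
# configuration). The input is shaped (batch, bands, frames, H, W):
#   model = Messis({
#       'heads_spec': heads_spec,  # as sketched above
#       'img_size': 224, 'patch_size': 16, 'num_frames': 3,
#       'bands': [0, 1, 2, 3, 4, 5],
#       'freeze_backbone': True,
#       'optimizer': 'AdamW', 'lr': 1e-3,
#   })
#   outputs = model(torch.randn(2, 6, 3, 224, 224))  # dict: head name -> logits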

class LogConfusionMatrix(pl.Callback):
    def __init__(self, hparams, dataset_info_file, debug=False):
        super().__init__()
        assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
        self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
        self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
        self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
        assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
        assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
        self.tiers = list(self.tiers_dict.keys())
        self.phases = ['train', 'val', 'test']
        self.modes = ['pixelwise', 'majority']
        self.debug = debug
        if debug:
            print(f"Final head identified as: {self.final_head_name}")
            print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
        with open(dataset_info_file, 'r') as f:
            self.dataset_info = json.load(f)

        # Initialize confusion matrices
        self.metrics_to_compute = ['confusion_matrix']
        self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}

    def __init_metrics(self, tier, phase):
        num_classes = self.tiers_dict[tier]['num_classes_to_predict']
        confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes)
        return {
            'confusion_matrix': confusion_matrix
        }

    def setup(self, trainer, pl_module, stage=None):
        # Move all metrics to the correct device at the start of training/validation
        device = pl_module.device
        for phase_metrics in self.metrics.values():
            for tier_metrics in phase_metrics.values():
                for mode_metrics in tier_metrics.values():
                    for metric in self.metrics_to_compute:
                        mode_metrics[metric].to(device)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val')

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test')

    def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase):
        if trainer.sanity_checking:
            return
        targets = torch.stack(batch[1][0])  # (tiers, batch, H, W)
        outputs = outputs['outputs'][self.final_head_name]  # (batch, C, H, W)
        field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
        pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)

        for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
            # Update all metrics
            assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}"
            assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}"
            for pred, target, tier in zip(preds, targets, self.tiers):
                if self.debug:
                    print(f"Updating confusion matrix for {phase} {tier} {mode}")
                metrics = self.metrics[phase][tier][mode]
                # In majority mode, drop background pixels so the background class does not appear in the confusion matrix
                if mode == 'majority':
                    pred = pred[target != 0]
                    target = target[target != 0]
                metrics['confusion_matrix'].update(pred, target)

    @staticmethod
    def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info):
        """
        Get the pixelwise and majority predictions from the model outputs.

        The pixelwise predictions for the highest tier are derived directly from refinement_head_outputs;
        the majority predictions for the highest tier are obtained by field-wise majority voting. The
        lower-tier predictions are then derived from the highest-tier predictions via the dataset's tier
        mappings. The background is set to 0 for all field majority predictions (regardless of what the
        model predicts for the background class), since this is a classification task rather than a
        segmentation task: the field boundaries are known beforehand and not of any interest.

        Args:
            refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier.
            tiers (list of str): List of tiers, e.g. ['tier1', 'tier2', 'tier3'].
            field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
            dataset_info (dict): The dataset information.

        Returns:
            torch.Tensor(tiers, batch, H, W): The pixelwise predictions.
            torch.Tensor(tiers, batch, H, W): The majority predictions.
        """
        # Assuming the highest tier is the last one in the list
        highest_tier = tiers[-1]
        pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1)  # (batch, H, W)
        majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids)

        tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier}
        pixelwise_outputs = {highest_tier: pixelwise_highest_tier}
        majority_outputs = {highest_tier: majority_highest_tier}

        # Initialize pixelwise and majority outputs for each lower tier
        for tier in tiers:
            if tier != highest_tier:
                pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier)
                majority_outputs[tier] = torch.zeros_like(majority_highest_tier)

        # Map the highest-tier classes to the lower tiers
        for i, mappings in enumerate(zip(*tier_mapping.values())):
            for j, tier in enumerate(tier_mapping.keys()):
                pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j]
                majority_outputs[tier][majority_highest_tier == i] = mappings[j]

        pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers])
        majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers])

        # Ensure these are tensors
        assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor"
        assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor"
        return pixelwise_outputs_stacked, majority_outputs_stacked
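
    # Mapping sketch (values are illustrative; key names follow the f'{highest_tier}_to_{tier}'
    # convention used above): dataset_info['tier3_to_tier1'][c] is the tier-1 class for
    # tier-3 class c, e.g.
    #   dataset_info = {'tier3_to_tier1': [0, 1, 1, 2], 'tier3_to_tier2': [0, 1, 2, 3], ...}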

    @staticmethod
    def get_field_majority_preds(output, field_ids):
        """
        Get the majority prediction for each field in the batch. The majority excludes the background class.

        Args:
            output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined).
            field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.

        Returns:
            torch.Tensor(batch, H, W): The majority predictions.
        """
        # Remove the background class before taking the argmax
        pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1  # (batch, H, W)
        majority_preds = torch.zeros_like(pixelwise)
        for batch in range(len(pixelwise)):
            field_ids_batch = field_ids[batch]
            for field_id in np.unique(field_ids_batch.cpu().numpy()):
                if field_id == 0:
                    continue
                field_mask = field_ids_batch == field_id
                flattened_pred = pixelwise[batch][field_mask].view(-1)  # Flatten the prediction
                flattened_pred = flattened_pred[flattened_pred != 0]  # Exclude the background class
                if len(flattened_pred) == 0:
                    continue
                mode_pred, _ = torch.mode(flattened_pred)  # Compute the mode prediction
                majority_preds[batch][field_mask] = mode_pred.item()
        return majority_preds
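
    # Worked example (values are illustrative): if the pixelwise predictions inside one field
    # are [3, 3, 5, 3], torch.mode picks 3, so every pixel of that field is assigned class 3:
    #   torch.mode(torch.tensor([3, 3, 5, 3]))  # -> values=tensor(3)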

    def on_train_epoch_end(self, trainer, pl_module):
        # Log and then reset the confusion matrices after the training epoch
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train')

    def on_validation_epoch_end(self, trainer, pl_module):
        # Log and then reset the confusion matrices after the validation epoch
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val')

    def on_test_epoch_end(self, trainer, pl_module):
        # Log and then reset the confusion matrices after the test epoch
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test')

    def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase):
        if trainer.sanity_checking:
            return
        for tier in self.tiers:
            for mode in self.modes:
                metrics = self.metrics[phase][tier][mode]
                confusion_matrix = metrics['confusion_matrix']
                if self.debug:
                    print(f"Logging and resetting confusion matrix for {phase} {tier} {mode}. Update count: {confusion_matrix._update_count}")
                matrix = confusion_matrix.compute()  # Columns are predictions, rows are targets

                # Calculate percentages
                matrix = matrix.float()
                row_sums = matrix.sum(dim=1, keepdim=True)
                matrix_percent = matrix / row_sums

                # Ensure percentages sum to 1 for each valid (non-NaN) row
                row_sum_check = matrix_percent.sum(dim=1)
                valid_rows = ~torch.isnan(row_sum_check)
                if valid_rows.any():
                    assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows"

                # Sort the matrix and labels by the total number of instances
                sorted_indices = row_sums.squeeze().argsort(descending=True)
                matrix_percent = matrix_percent[sorted_indices, :]  # Sort rows
                matrix_percent = matrix_percent[:, sorted_indices]  # Sort columns
                class_labels = [self.dataset_info[tier][i] for i in sorted_indices]
                row_sums_sorted = row_sums[sorted_indices]

                # Check for zero rows after sorting
                zero_rows = (row_sums_sorted == 0).squeeze()

                fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140)
                ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis')
                ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1)))
                ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1)))
                ax.set_xticklabels(class_labels + [''], rotation=45)
                ax.set_yticklabels(class_labels + [''])

                # Add the total number of instances to the y-axis labels
                y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))]
                ax.set_yticklabels(y_labels + [''])
                ax.set_xlabel('Predictions')
                ax.set_ylabel('Targets')

                # Move the x-axis label and ticks to the top
                ax.xaxis.set_label_position('top')
                ax.xaxis.set_ticks_position('top')
                fig.tight_layout()

                for i in range(matrix.size(0)):
                    for j in range(matrix.size(1)):
                        if zero_rows[i]:
                            ax.text(j, i, 'N/A', ha='center', va='center', color='black')
                        else:
                            ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold')  # Coral red

                trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)})
                plt.close(fig)
                confusion_matrix.reset()

class LogMessisMetrics(pl.Callback):
    def __init__(self, hparams, dataset_info_file, debug=False):
        super().__init__()
        assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
        self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
        self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
        self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
        assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
        assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
        self.tiers = list(self.tiers_dict.keys())
        self.phases = ['train', 'val', 'test']
        self.modes = ['pixelwise', 'majority']
        self.debug = debug
        if debug:
            print(f"Last tier identified as: {self.last_tier_name}")
            print(f"Final head identified as: {self.final_head_name}")
            print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
        with open(dataset_info_file, 'r') as f:
            self.dataset_info = json.load(f)

        # Initialize metrics
        self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall', 'f1', 'weighted_f1', 'cohen_kappa']
        self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
        self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases}
        self.images_to_log_targets = {phase: None for phase in self.phases}
        self.field_ids_to_log_targets = {phase: None for phase in self.phases}
        self.inputs_to_log = {phase: None for phase in self.phases}

    def __init_metrics(self, tier, phase):
        num_classes = self.tiers_dict[tier]['num_classes_to_predict']
        accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro')
        weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted')
        per_class_accuracies = {
            class_index: classification.BinaryAccuracy() for class_index in range(num_classes)
        }
        precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro')
        weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted')
        recall = classification.MulticlassRecall(num_classes=num_classes, average='macro')
        weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted')
        f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro')
        weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted')
        cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes)
        return {
            'accuracy': accuracy,
            'weighted_accuracy': weighted_accuracy,
            'per_class_accuracies': per_class_accuracies,
            'precision': precision,
            'weighted_precision': weighted_precision,
            'recall': recall,
            'weighted_recall': weighted_recall,
            'f1': f1,
            'weighted_f1': weighted_f1,
            'cohen_kappa': cohen_kappa
        }

    def setup(self, trainer, pl_module, stage=None):
        # Move all metrics to the correct device at the start of training/validation
        device = pl_module.device
        for phase_metrics in self.metrics.values():
            for tier_metrics in phase_metrics.values():
                for mode_metrics in tier_metrics.values():
                    for metric in self.metrics_to_compute:
                        mode_metrics[metric].to(device)
                    for class_accuracy in mode_metrics['per_class_accuracies'].values():
                        class_accuracy.to(device)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val')

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test')

    def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase):
        if trainer.sanity_checking:
            return
        if self.debug:
            print(f"{phase} batch ended. Updating metrics...")
        targets = torch.stack(batch[1][0])  # (tiers, batch, H, W)
        outputs = outputs['outputs'][self.final_head_name]  # (batch, C, H, W)
        field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
        pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)

        for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
            # Update all metrics
            assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}"
            assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}"
            self.images_to_log[phase][mode] = preds[-1]
            for pred, target, tier in zip(preds, targets, self.tiers):
                # In majority mode, drop background pixels so the background class is not considered in the metrics
                if mode == 'majority':
                    pred = pred[target != 0]
                    target = target[target != 0]
                metrics = self.metrics[phase][tier][mode]
                for metric in self.metrics_to_compute:
                    metrics[metric].update(pred, target)
                    if self.debug:
                        print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}")
                self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies'])

        self.images_to_log_targets[phase] = targets[-1]
        self.field_ids_to_log_targets[phase] = field_ids
        self.inputs_to_log[phase] = batch[0]

    def __update_per_class_metrics(self, preds, targets, per_class_accuracies):
        for class_index, class_accuracy in per_class_accuracies.items():
            if not (targets == class_index).any():
                continue
            if class_index == 0:
                # Evaluate the background class (0) on background pixels only
                keep_mask = targets == 0
            else:
                # Evaluate field classes on non-background pixels only
                keep_mask = targets != 0
            preds_fields = preds[keep_mask]
            targets_fields = targets[keep_mask]
            # Prepare for binary classification (needs to be float)
            preds_class = (preds_fields == class_index).float()
            targets_class = (targets_fields == class_index).float()
            class_accuracy.update(preds_class, targets_class)
            if self.debug:
                print(f"Shape of preds_fields: {preds_fields.shape}")
                print(f"Shape of targets_fields: {targets_fields.shape}")
                print(f"Unique values in preds_fields: {torch.unique(preds_fields)}")
                print(f"Unique values in targets_fields: {torch.unique(targets_fields)}")
                print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}")

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'train')

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'val')

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'test')

    def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase):
        if trainer.sanity_checking:
            return  # Skip during the sanity check (avoids a warning about metric compute being called before update)
        if self.debug:
            print(f"{phase} epoch ended. Logging & resetting metrics...")
        for tier in self.tiers:
            for mode in self.modes:
                metrics = self.metrics[phase][tier][mode]

                # Compute and reset per tier: accuracy, weighted accuracy, precision, recall, F1, Cohen's kappa
                metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute}
                pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True)
                for metric in self.metrics_to_compute:
                    metrics[metric].reset()

                # Per-class metrics
                # NOTE: Some literature reports "per-class accuracy" when what is actually meant is "per-class recall".
                # Using the accuracy formula per class has little value in our imbalanced multi-class setting (TNs inflate the scores!).
                # We calculate all 4 metrics, which allows us to derive any macro/micro score later if needed.
                class_metrics = []
                class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier]
                for class_index, class_accuracy in metrics['per_class_accuracies'].items():
                    if class_accuracy._update_count == 0:
                        continue  # Skip if no updates have been made
                    tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn
                    recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0
                    precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0
                    f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0
                    n_of_class = (tp + fn).item()
                    class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class])
                    class_accuracy.reset()
                wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"])
                trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table})

        # Use the same n_classes for all images, so that they are comparable
        n_classes = max([
            torch.max(self.images_to_log_targets[phase]),
            torch.max(self.images_to_log[phase]["majority"]),
            torch.max(self.images_to_log[phase]["pixelwise"])
        ])
        images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes]
        images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase]))
        images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes))
        images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu()))
        examples = []
        for i in range(len(images[0])):
            example = np.concatenate([img[i] for img in images], axis=0)
            examples.append(wandb.Image(example, caption=f"From top to bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields"))
        trainer.logger.experiment.log({f"{phase}_examples": examples})

        # Log segmentation masks
        batch_input_data = self.inputs_to_log[phase].cpu()  # Shape (BS, 6, N_TIMESTEPS, 224, 224)
        ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy()
        pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy()
        field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy()
        correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks)
        class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])}
        segmentation_masks = []
        for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks):
            middle_timestep_index = input_data.shape[1] // 2  # Get the middle timestep index
            gamma = 2.5  # Gamma for brightness adjustment
            rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy()  # Shape (224, 224, 3)
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())
            rgb_image = np.power(rgb_image, 1.0 / gamma)
            rgb_image = (rgb_image * 255).astype(np.uint8)
            mask_img = wandb.Image(
                rgb_image,
                masks={
                    "predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels},
                    "predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels},
                    "ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels},
                    "correctness": {"mask_data": correctness_mask, "class_labels": {0: "Background", 1: "Wrong", 2: "Right"}},
                },
            )
            segmentation_masks.append(mask_img)
        trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks})

    @staticmethod
    def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks):
        """
        Create a mask that shows the positive and negative classifications of the model.

        Args:
            field_majority_masks (np.ndarray): The field majority masks generated by the model.
            ground_truth_masks (np.ndarray): The ground truth masks.

        Returns:
            np.ndarray: An array with values:
                - 0 where the target is 0,
                - 2 where the prediction matches the target,
                - 1 where the prediction does not match the target.
        """
        correctness_mask = np.zeros_like(ground_truth_masks, dtype=int)
        matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0)
        correctness_mask[matches] = 2
        mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0)
        correctness_mask[mismatches] = 1
        return correctness_mask
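
    # Worked example (values are illustrative):
    #   pred   = np.array([[1, 2], [0, 3]])
    #   target = np.array([[1, 3], [0, 3]])
    #   LogMessisMetrics.create_positive_negative_segmentation_mask(pred, target)
    #   # -> [[2, 1], [0, 2]]  (0 = background, 1 = wrong, 2 = right)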

    @staticmethod
    def create_positive_negative_image(generated_images, target_images):
        """
        Create an image that shows the positive and negative classifications of the model.

        Args:
            generated_images (torch.Tensor): The images generated by the model.
            target_images (torch.Tensor): The target images.

        Returns:
            list: A list of processed images.
        """
        classification_masks = generated_images == target_images
        processed_imgs = []
        for mask, target in zip(classification_masks, target_images):
            # Color the background black, right classifications green, and wrong classifications red
            colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
            mask = mask.bool()  # Convert to a boolean tensor
            colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8)
            colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8)
            colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8)
            processed_imgs.append(colored_img.cpu())
        return processed_imgs

    @staticmethod
    def process_images(imgs, max_value=None):
        """
        Process a batch of images to be logged on wandb.

        Args:
            imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed.
            max_value (float, optional): The maximum value used to normalize the images. Defaults to None.
                If None, the maximum value in the batch is used.

        Returns:
            list: A list of colored images as (H, W, 3) uint8 arrays.
        """
        if max_value is None:
            max_value = np.max(imgs.cpu().numpy())
        normalized_img = imgs / max_value
        processed_imgs = []
        for img in normalized_img.cpu().numpy():
            if max_value < 60:
                # Categorical palette for small class counts
                cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors)
            else:
                cmap = plt.get_cmap('viridis')
            colored_img = cmap(img)
            colored_img[img == 0] = [0, 0, 0, 1]  # Render background pixels black
            colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8)
            processed_imgs.append(colored_img_uint8)
        return processed_imgs
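
# Usage sketch (illustrative values): for fewer than 60 classes a categorical tab20-based
# palette is used, otherwise viridis; background pixels (value 0) are rendered black:
#   imgs = torch.randint(0, 10, (2, 224, 224))
#   rgb_list = LogMessisMetrics.process_images(imgs, max_value=9)  # two (224, 224, 3) uint8 arrays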