import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import classification
import wandb
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
from matplotlib.colors import ListedColormap
from huggingface_hub import PyTorchModelHubMixin
from lion_pytorch import Lion
import json
from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck

def safe_shape(x):
    """Return the shape of a tensor, or a string describing the shapes inside a tuple/list of tensors."""
    if isinstance(x, (tuple, list)):
        kind = 'tuple' if isinstance(x, tuple) else 'list'
        return f'({kind}) : ' + ''.join(str(i.shape) + ', ' for i in x)
    return x.shape
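
# Illustrative only (hypothetical shapes): what safe_shape returns for the kinds
# of values passed around in this file.
#
#   safe_shape(torch.zeros(4, 6, 9, 224, 224))    # torch.Size([4, 6, 9, 224, 224])
#   safe_shape((torch.zeros(4, 589, 768),))       # '(tuple) : torch.Size([4, 589, 768]), '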
class ConvModule(nn.Module):
"""
A simple convolutional module including Conv, BatchNorm, and ReLU layers.
"""
def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
super(ConvModule, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
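
# A minimal sanity-check sketch (hypothetical sizes): with
# padding = dilation * (kernel_size // 2), as HierarchicalFCNHead passes below,
# the spatial resolution is preserved.
#
#   m = ConvModule(in_channels=16, out_channels=32, kernel_size=3, padding=1, dilation=1)
#   m(torch.randn(2, 16, 224, 224)).shape   # torch.Size([2, 32, 224, 224])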
class HierarchicalFCNHead(nn.Module):
"""
Hierarchical FCN Head for semantic segmentation.
"""
def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False):
super(HierarchicalFCNHead, self).__init__()
self.debug = debug
self.convs = nn.Sequential(*[
ConvModule(
in_channels if i == 0 else out_channels,
out_channels,
kernel_size,
padding=dilation * (kernel_size // 2),
dilation=dilation
) for i in range(num_convs)
])
self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1)
self.dropout = nn.Dropout2d(p=dropout_p)
def forward(self, x):
if self.debug:
print('HierarchicalFCNHead forward INP: ', safe_shape(x))
x = self.convs(x)
features = self.dropout(x)
output = self.conv_seg(features)
if self.debug:
print('HierarchicalFCNHead forward features OUT: ', safe_shape(features))
print('HierarchicalFCNHead forward output OUT: ', safe_shape(output))
return output, features
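
# Usage sketch (hypothetical sizes): each head returns (logits, features), and
# HierarchicalClassifier.forward feeds the features into the next head, chaining
# the tiers together.
#
#   head = HierarchicalFCNHead(in_channels=2304, out_channels=256, num_classes=4)
#   logits, features = head(torch.randn(2, 2304, 224, 224))
#   logits.shape, features.shape   # torch.Size([2, 4, 224, 224]), torch.Size([2, 256, 224, 224])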
class LabelRefinementHead(nn.Module):
"""
Similar to the label refinement module introduced in the ZueriCrop paper, this module refines the predictions for tier 3.
It takes the raw predictions from head 1, head 2 and head 3 and refines them to produce the final prediction for tier 3.
    According to ZueriCrop, this helps make the predictions more consistent across the different tiers.
"""
def __init__(self, input_channels, num_classes):
super(LabelRefinementHead, self).__init__()
self.cnn_layers = nn.Sequential(
# 1x1 Convolutional layer
nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
# 3x3 Convolutional layer
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5),
# Skip connection (implemented in forward method)
# Another 3x3 Convolutional layer
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
# 1x1 Convolutional layer to adjust the number of output channels to num_classes
nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
nn.Dropout(p=0.5)
)
def forward(self, x):
        # Initial 1x1 conv block (conv, BN, ReLU)
        y = self.cnn_layers[0:3](x)
        # Save for the skip connection
        y_skip = y
        # First 3x3 conv block (with dropout) and the second 3x3 conv with its BN
        y = self.cnn_layers[3:9](y)
        # Skip connection (element-wise addition), applied before the final ReLU
        y = y + y_skip
        # Final ReLU, 1x1 conv down to num_classes, and dropout
        y = self.cnn_layers[9:](y)
return y
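
# Usage sketch (hypothetical class counts): the refinement head consumes the
# concatenated raw logits of all tiers (4 + 20 + 50 channels here) and emits the
# final logits for the last tier.
#
#   head = LabelRefinementHead(input_channels=74, num_classes=50)
#   head(torch.randn(2, 74, 224, 224)).shape   # torch.Size([2, 50, 224, 224])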
class HierarchicalClassifier(nn.Module):
def __init__(
self,
heads_spec,
dropout_p=0.1,
img_size=256,
patch_size=16,
num_frames=3,
bands=[0, 1, 2, 3, 4, 5],
backbone_weights_path=None,
freeze_backbone=True,
use_bottleneck_neck=False,
bottleneck_reduction_factor=4,
loss_ignore_background=False,
debug=False
):
super(HierarchicalClassifier, self).__init__()
self.embed_dim = 768
if num_frames % 3 != 0:
            raise ValueError(f"The number of frames must be a multiple of 3, got {num_frames}")
self.num_frames = num_frames
self.hp, self.wp = img_size // patch_size, img_size // patch_size
self.heads_spec = heads_spec
self.dropout_p = dropout_p
self.loss_ignore_background = loss_ignore_background
self.debug = debug
if self.debug:
print('hp and wp: ', self.hp, self.wp)
self.prithvi = TemporalViTEncoder(
img_size=img_size,
patch_size=patch_size,
num_frames=3,
tubelet_size=1,
in_chans=len(bands),
embed_dim=self.embed_dim,
depth=12,
num_heads=8,
mlp_ratio=4.0,
norm_pix_loss=False,
pretrained=backbone_weights_path,
debug=self.debug
)
# (Un)freeze the backbone
for param in self.prithvi.parameters():
param.requires_grad = not freeze_backbone
# Neck to transform the token-based output of the transformer into a spatial feature map
number_of_necks = self.num_frames // 3
if use_bottleneck_neck:
self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck(
embed_dim=self.embed_dim * 3,
output_embed_dim=self.embed_dim * 3,
drop_cls_token=True,
Hp=self.hp,
Wp=self.wp,
bottleneck_reduction_factor=bottleneck_reduction_factor
) for _ in range(number_of_necks)])
else:
self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck(
embed_dim=self.embed_dim * 3,
output_embed_dim=self.embed_dim * 3,
drop_cls_token=True,
Hp=self.hp,
Wp=self.wp,
) for _ in range(number_of_necks)])
# Initialize heads and loss weights based on tiers
self.heads = nn.ModuleDict()
self.loss_weights = {}
self.total_classes = 0
# Build HierarchicalFCNHeads
head_count = 0
for head_name, head_info in self.heads_spec.items():
head_type = head_info['type']
num_classes = head_info['num_classes_to_predict']
loss_weight = head_info['loss_weight']
if head_type == 'HierarchicalFCNHead':
kernel_size = head_info.get('kernel_size', 3)
num_convs = head_info.get('num_convs', 1)
num_channels = head_info.get('num_channels', 256)
self.total_classes += num_classes
self.heads[head_name] = HierarchicalFCNHead(
in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels,
out_channels=num_channels,
num_classes=num_classes,
num_convs=num_convs,
kernel_size=kernel_size,
dropout_p=self.dropout_p,
debug=self.debug
)
self.loss_weights[head_name] = loss_weight
# NOTE: LabelRefinementHead must be the last in the dict, otherwise the total_classes will be incorrect
if head_type == 'LabelRefinementHead':
self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes)
self.refinement_head_name = head_name
self.loss_weights[head_name] = loss_weight
head_count += 1
self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)
def forward(self, x):
if self.debug:
print(f"Input shape: {safe_shape(x)}") # torch.Size([4, 6, 9, 224, 224])
# Extract features from the base model
if len(self.necks) == 1:
features = [x]
else:
features = torch.chunk(x, len(self.necks), dim=2)
        features = [self.prithvi(chunk) for chunk in features]
        if self.debug:
            print(f"Features shape after base model: {', '.join([safe_shape(f) for f in features])}") # e.g. (tuple) : torch.Size([4, 589, 768]), ...
# Process through the neck
features = [neck(feat_) for feat_, neck in zip(features, self.necks)]
if self.debug:
print(f"Features shape after neck: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 2304, 224, 224]), , (tuple) : torch.Size
# Remove from tuple
features = [feat[0] for feat in features]
        # Concatenate the per-neck feature maps along the channel dimension
        features = torch.cat(features, dim=1)
        if self.debug:
            print(f"Features shape after concatenation: {safe_shape(features)}") # e.g. torch.Size([4, 6912, 224, 224])
# Process through the heads
outputs = {}
for tier_name, head in self.heads.items():
output, features = head(features)
outputs[tier_name] = output
if self.debug:
print(f"Features shape after {tier_name} head: {safe_shape(features)}")
print(f"Output shape after {tier_name} head: {safe_shape(output)}")
# Process through the classification refinement head
output_concatenated = torch.cat(list(outputs.values()), dim=1)
output_refinement_head = self.refinement_head(output_concatenated)
outputs[self.refinement_head_name] = output_refinement_head
return outputs
def calculate_loss(self, outputs, targets):
total_loss = 0
loss_per_head = {}
for head_name, output in outputs.items():
if self.debug:
print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}")
target = targets[self.heads_spec[head_name]['target_idx']]
loss_target = target
if self.loss_ignore_background:
loss_target = target.clone() # Clone as original target needed in backward pass
loss_target[loss_target == 0] = -1 # Set background class to ignore_index -1 for loss calculation
loss = self.loss_func(output, loss_target)
loss_per_head[f'{head_name}'] = loss
total_loss += loss * self.loss_weights[head_name]
return total_loss, loss_per_head
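
# An illustrative heads_spec (hypothetical tier names, class counts, and loss
# weights; the keys shown are the ones read by HierarchicalClassifier above and
# by the metric callbacks below). The LabelRefinementHead entry must come last,
# as noted above, so that total_classes is complete when it is built:
#
#   HEADS_SPEC_EXAMPLE = {
#       'tier1': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 4,
#                 'loss_weight': 0.2, 'target_idx': 0, 'is_metrics_tier': True,
#                 'num_convs': 1, 'num_channels': 256, 'kernel_size': 3},
#       'tier2': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 20,
#                 'loss_weight': 0.3, 'target_idx': 1, 'is_metrics_tier': True},
#       'tier3': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 50,
#                 'loss_weight': 0.5, 'target_idx': 2, 'is_metrics_tier': True,
#                 'is_last_tier': True},
#       'tier3_refined': {'type': 'LabelRefinementHead', 'num_classes_to_predict': 50,
#                         'loss_weight': 1.0, 'target_idx': 2, 'is_final_head': True},
#   }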
class Messis(pl.LightningModule, PyTorchModelHubMixin):
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)
        # Fall back to the HierarchicalClassifier defaults when a key is absent
        # (a bare .get() would pass None and silently override them)
        self.model = HierarchicalClassifier(
            heads_spec=hparams['heads_spec'],
            dropout_p=hparams.get('dropout_p', 0.1),
            img_size=hparams.get('img_size', 256),
            patch_size=hparams.get('patch_size', 16),
            num_frames=hparams.get('num_frames', 3),
            bands=hparams.get('bands', [0, 1, 2, 3, 4, 5]),
            backbone_weights_path=hparams.get('backbone_weights_path'),
            freeze_backbone=hparams['freeze_backbone'],
            use_bottleneck_neck=hparams.get('use_bottleneck_neck', False),
            bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor', 4),
            loss_ignore_background=hparams.get('loss_ignore_background', False),
            debug=hparams.get('debug', False)
        )
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
return self.__step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
return self.__step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self.__step(batch, batch_idx, "test")
def configure_optimizers(self):
        # Select the optimizer specified in hparams (defaults to Adam)
match self.hparams.get('optimizer', 'Adam'):
case 'Adam':
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3))
case 'AdamW':
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01))
case 'SGD':
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9))
case 'Lion':
# https://github.com/lucidrains/lion-pytorch | Typically lr 3-10 times lower than Adam and weight_decay 3-10 times higher
optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1))
case _:
raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported")
return optimizer
def __step(self, batch, batch_idx, stage):
inputs, targets = batch
targets = torch.stack(targets[0])
outputs = self(inputs)
loss, loss_per_head = self.model.calculate_loss(outputs, targets)
loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head}
loss_proportions = { f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head}
loss_detail_dict = {**loss_per_head_named, **loss_proportions}
if self.hparams.get('debug'):
print(f"Step Inputs shape: {safe_shape(inputs)}")
print(f"Step Targets shape: {safe_shape(targets)}")
print(f"Step Outputs dict keys: {outputs.keys()}")
# NOTE: All metrics other than loss are tracked by callbacks (LogMessisMetrics)
self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return {'loss': loss, 'outputs': outputs}
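
# A minimal construction sketch (hypothetical values; 'heads_spec' and
# 'freeze_backbone' are required, every other key falls back to the defaults
# used above):
#
#   model = Messis({
#       'heads_spec': HEADS_SPEC_EXAMPLE,               # e.g. the sketch above
#       'freeze_backbone': True,
#       'img_size': 224, 'patch_size': 16, 'num_frames': 9,
#       'bands': [0, 1, 2, 3, 4, 5],
#       'backbone_weights_path': 'weights/prithvi.pt',  # hypothetical path
#       'optimizer': 'AdamW', 'lr': 1e-3,
#   })
#   outputs = model(torch.randn(4, 6, 9, 224, 224))     # dict: head name -> logits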
class LogConfusionMatrix(pl.Callback):
def __init__(self, hparams, dataset_info_file, debug=False):
super().__init__()
assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
self.tiers = list(self.tiers_dict.keys())
self.phases = ['train', 'val', 'test']
self.modes = ['pixelwise', 'majority']
self.debug = debug
if debug:
print(f"Final head identified as: {self.final_head_name}")
print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
with open(dataset_info_file, 'r') as f:
self.dataset_info = json.load(f)
# Initialize confusion matrices
self.metrics_to_compute = ['confusion_matrix']
self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
def __init_metrics(self, tier, phase):
num_classes = self.tiers_dict[tier]['num_classes_to_predict']
confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes)
return {
'confusion_matrix': confusion_matrix
}
def setup(self, trainer, pl_module, stage=None):
# Move all metrics to the correct device at the start of the training/validation
device = pl_module.device
for phase_metrics in self.metrics.values():
for tier_metrics in phase_metrics.values():
for mode_metrics in tier_metrics.values():
for metric in self.metrics_to_compute:
mode_metrics[metric].to(device)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train')
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val')
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test')
def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase):
if trainer.sanity_checking:
return
targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
# Update all metrics
assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}"
assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}"
for pred, target, tier in zip(preds, targets, self.tiers):
if self.debug:
print(f"Updating confusion matrix for {phase} {tier} {mode}")
metrics = self.metrics[phase][tier][mode]
# flatten and remove background class if the mode is majority (such that the background class is not included in the confusion matrix)
if mode == 'majority':
pred = pred[target != 0]
target = target[target != 0]
metrics['confusion_matrix'].update(pred, target)
@staticmethod
def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info):
"""
Get the pixelwise and majority predictions from the model outputs.
The pixelwise tier predictions are derived from the refinement_head_outputs predictions.
        The majority highest-tier predictions are derived from the refinement_head_outputs; the majority lower-tier predictions are then derived from the majority highest-tier predictions via the tier mappings in dataset_info.
        Also sets the background to 0 for all field majority predictions (regardless of what the model predicts for the background class):
        since this is a classification task rather than a segmentation task, the field boundaries are known beforehand and are not of interest.
Args:
refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier.
tiers (list of str): List of tiers e.g. ['tier1', 'tier2', 'tier3'].
field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
dataset_info (dict): The dataset information.
Returns:
torch.Tensor(tiers, batch, H, W): The pixelwise predictions.
torch.Tensor(tiers, batch, H, W): The majority predictions.
"""
# Assuming the highest tier is the last one in the list
highest_tier = tiers[-1]
pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1) # (batch, H, W)
majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids)
tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier}
pixelwise_outputs = {highest_tier: pixelwise_highest_tier}
majority_outputs = {highest_tier: majority_highest_tier}
# Initialize pixelwise and majority outputs for each tier
for tier in tiers:
if tier != highest_tier:
pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier)
majority_outputs[tier] = torch.zeros_like(majority_highest_tier)
# Map the highest tier to lower tiers
for i, mappings in enumerate(zip(*tier_mapping.values())):
for j, tier in enumerate(tier_mapping.keys()):
pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j]
majority_outputs[tier][majority_highest_tier == i] = mappings[j]
pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers])
majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers])
# Ensure these are tensors
assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor"
assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor"
return pixelwise_outputs_stacked, majority_outputs_stacked
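
    # Mapping sketch (hypothetical values): with tiers ['tier1', 'tier2', 'tier3']
    # and dataset_info['tier3_to_tier1'] = [0, 1, 1, 2, ...], every pixel predicted
    # as tier-3 class i is assigned tier-1 class dataset_info['tier3_to_tier1'][i],
    # for both the pixelwise and the majority outputs.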
@staticmethod
def get_field_majority_preds(output, field_ids):
"""
Get the majority prediction for each field in the batch. The majority excludes the background class.
Args:
output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined)
field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
Returns:
torch.Tensor(batch, H, W): The majority predictions.
"""
# remove the background class
pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1 # (batch, H, W)
majority_preds = torch.zeros_like(pixelwise)
for batch in range(len(pixelwise)):
field_ids_batch = field_ids[batch]
for field_id in np.unique(field_ids_batch.cpu().numpy()):
if field_id == 0:
continue
field_mask = field_ids_batch == field_id
flattened_pred = pixelwise[batch][field_mask].view(-1) # Flatten the prediction
flattened_pred = flattened_pred[flattened_pred != 0] # Exclude background class
if len(flattened_pred) == 0:
continue
mode_pred, _ = torch.mode(flattened_pred) # Compute mode prediction
majority_preds[batch][field_mask] = mode_pred.item()
return majority_preds
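
    # Worked example (hypothetical): if the pixels of one field are predicted as
    # [3, 3, 7, 3], torch.mode picks 3 and the whole field is set to 3; fields whose
    # non-background predictions are empty keep their initial value of 0.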
def on_train_epoch_end(self, trainer, pl_module):
# Log and then reset the confusion matrices after training epoch
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train')
def on_validation_epoch_end(self, trainer, pl_module):
# Log and then reset the confusion matrices after validation epoch
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val')
def on_test_epoch_end(self, trainer, pl_module):
# Log and then reset the confusion matrices after test epoch
self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test')
def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase):
if trainer.sanity_checking:
return
for tier in self.tiers:
for mode in self.modes:
metrics = self.metrics[phase][tier][mode]
confusion_matrix = metrics['confusion_matrix']
if self.debug:
print(f"Logging and resetting confusion matrix for {phase} {tier} Update count: {confusion_matrix._update_count}")
matrix = confusion_matrix.compute() # columns are predictions and rows are targets
# Calculate percentages
matrix = matrix.float()
row_sums = matrix.sum(dim=1, keepdim=True)
matrix_percent = matrix / row_sums
# Ensure percentages sum to 1 for each row or handle NaNs
row_sum_check = matrix_percent.sum(dim=1)
valid_rows = ~torch.isnan(row_sum_check)
if valid_rows.any():
assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows"
# Sort the matrix and labels by the total number of instances
sorted_indices = row_sums.squeeze().argsort(descending=True)
matrix_percent = matrix_percent[sorted_indices, :] # sort rows
matrix_percent = matrix_percent[:, sorted_indices] # sort columns
class_labels = [self.dataset_info[tier][i] for i in sorted_indices]
row_sums_sorted = row_sums[sorted_indices]
# Check for zero rows after sorting
zero_rows = (row_sums_sorted == 0).squeeze()
fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140)
ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis')
ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1)))
ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1)))
                ax.set_xticklabels(class_labels + [''], rotation=45)
                # Add the total number of instances to the y-axis labels
                y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))]
                ax.set_yticklabels(y_labels + [''])
ax.set_xlabel('Predictions')
ax.set_ylabel('Targets')
# Move x-axis label and ticks to the top
ax.xaxis.set_label_position('top')
ax.xaxis.set_ticks_position('top')
fig.tight_layout()
for i in range(matrix.size(0)):
for j in range(matrix.size(1)):
if zero_rows[i]:
ax.text(j, i, 'N/A', ha='center', va='center', color='black')
else:
ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold') # coral red
trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)})
                plt.close(fig)
confusion_matrix.reset()
class LogMessisMetrics(pl.Callback):
def __init__(self, hparams, dataset_info_file, debug=False):
super().__init__()
assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
self.tiers = list(self.tiers_dict.keys())
self.phases = ['train', 'val', 'test']
self.modes = ['pixelwise', 'majority']
self.debug = debug
if debug:
print(f"Last tier identified as: {self.last_tier_name}")
print(f"Final head identified as: {self.final_head_name}")
print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
with open(dataset_info_file, 'r') as f:
self.dataset_info = json.load(f)
# Initialize metrics
        self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall', 'f1', 'weighted_f1', 'cohen_kappa']
self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases}
self.images_to_log_targets = {phase: None for phase in self.phases}
self.field_ids_to_log_targets = {phase: None for phase in self.phases}
self.inputs_to_log = {phase: None for phase in self.phases}
def __init_metrics(self, tier, phase):
num_classes = self.tiers_dict[tier]['num_classes_to_predict']
accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro')
weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted')
per_class_accuracies = {
class_index: classification.BinaryAccuracy() for class_index in range(num_classes)
}
precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro')
weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted')
recall = classification.MulticlassRecall(num_classes=num_classes, average='macro')
weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted')
f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro')
weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted')
cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes)
return {
'accuracy': accuracy,
'weighted_accuracy': weighted_accuracy,
'per_class_accuracies': per_class_accuracies,
'precision': precision,
'weighted_precision': weighted_precision,
'recall': recall,
'weighted_recall': weighted_recall,
'f1': f1,
'weighted_f1': weighted_f1,
'cohen_kappa': cohen_kappa
}
def setup(self, trainer, pl_module, stage=None):
# Move all metrics to the correct device at the start of the training/validation
device = pl_module.device
for phase_metrics in self.metrics.values():
for tier_metrics in phase_metrics.values():
for mode_metrics in tier_metrics.values():
for metric in self.metrics_to_compute:
mode_metrics[metric].to(device)
for class_accuracy in mode_metrics['per_class_accuracies'].values():
class_accuracy.to(device)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train')
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val')
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test')
def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase):
if trainer.sanity_checking:
return
if self.debug:
print(f"{phase} batch ended. Updating metrics...")
targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
# Update all metrics
assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}"
assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}"
self.images_to_log[phase][mode] = preds[-1]
for pred, target, tier in zip(preds, targets, self.tiers):
# flatten and remove background class if the mode is majority (such that the background class is not considered in the metrics)
if mode == 'majority':
pred = pred[target != 0]
target = target[target != 0]
metrics = self.metrics[phase][tier][mode]
for metric in self.metrics_to_compute:
metrics[metric].update(pred, target)
if self.debug:
print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}")
self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies'])
self.images_to_log_targets[phase] = targets[-1]
self.field_ids_to_log_targets[phase] = field_ids
self.inputs_to_log[phase] = batch[0]
def __update_per_class_metrics(self, preds, targets, per_class_accuracies):
for class_index, class_accuracy in per_class_accuracies.items():
if not (targets == class_index).any():
continue
if class_index == 0:
# Mask out non-background elements for background class (0)
class_mask = targets != 0
else:
# Mask out background elements for other classes
class_mask = targets == 0
preds_fields = preds[~class_mask]
targets_fields = targets[~class_mask]
# Prepare for binary classification (needs to be float)
preds_class = (preds_fields == class_index).float()
targets_class = (targets_fields == class_index).float()
class_accuracy.update(preds_class, targets_class)
if self.debug:
print(f"Shape of preds_fields: {preds_fields.shape}")
print(f"Shape of targets_fields: {targets_fields.shape}")
print(f"Unique values in preds_fields: {torch.unique(preds_fields)}")
print(f"Unique values in targets_fields: {torch.unique(targets_fields)}")
print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}")
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
self.__on_epoch_end(trainer, pl_module, 'train')
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
self.__on_epoch_end(trainer, pl_module, 'val')
def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
self.__on_epoch_end(trainer, pl_module, 'test')
def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase):
if trainer.sanity_checking:
return # Skip during sanity check (avoid warning about metric compute being called before update)
for tier in self.tiers:
for mode in self.modes:
metrics = self.metrics[phase][tier][mode]
# Calculate and reset in tier: Accuracy, WeightedAccuracy, Precision, Recall, F1, Cohen's Kappa
metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute}
pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True)
for metric in self.metrics_to_compute:
metrics[metric].reset()
# Per-class metrics
# NOTE: Some literature reports "per class accuracy" but what they actually mean is "per class recall".
# Using the accuracy formula per class has no value in our imbalanced multi-class setting (TN's inflate scores!)
# We calculate all 4 metrics. This allows us to calculate any macro/micro score later if needed.
class_metrics = []
class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier]
for class_index, class_accuracy in metrics['per_class_accuracies'].items():
if class_accuracy._update_count == 0:
continue # Skip if no updates have been made
tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn
recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0
precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0
f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0
n_of_class = (tp + fn).item()
class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class])
class_accuracy.reset()
wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"])
trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table})
# use the same n_classes for all images, such that they are comparable
n_classes = max([
torch.max(self.images_to_log_targets[phase]),
torch.max(self.images_to_log[phase]["majority"]),
torch.max(self.images_to_log[phase]["pixelwise"])
])
images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes]
images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase]))
images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes))
images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu()))
examples = []
for i in range(len(images[0])):
example = np.concatenate([img[i] for img in images], axis=0)
examples.append(wandb.Image(example, caption=f"From Top to Bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields"))
trainer.logger.experiment.log({f"{phase}_examples": examples})
# Log segmentation masks
batch_input_data = self.inputs_to_log[phase].cpu() # shape [BS, 6, N_TIMESTEPS, 224, 224]
ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy()
pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy()
field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy()
correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks)
class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])}
segmentation_masks = []
for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks):
middle_timestep_index = input_data.shape[1] // 2 # Get the middle timestamp index
gamma = 2.5 # Gamma for brightness adjustment
rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy() # Shape [224, 224, 3]
rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())
rgb_image = np.power(rgb_image, 1.0 / gamma)
rgb_image = (rgb_image * 255).astype(np.uint8)
mask_img = wandb.Image(
rgb_image,
masks={
"predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels},
"predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels},
"ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels},
"correctness": {"mask_data": correctness_mask, "class_labels": { 0: "Background", 1: "Wrong", 2: "Right" }},
},
)
segmentation_masks.append(mask_img)
trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks})
if self.debug:
print(f"{phase} epoch ended. Logging & resetting metrics...", trainer.sanity_checking)
@staticmethod
def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks):
"""
Create a tensor that shows the positive and negative classifications of the model.
Args:
field_majority_masks (np.ndarray): The field majority masks generated by the model.
ground_truth_masks (np.ndarray): The ground truth masks.
Returns:
np.ndarray: An array with values:
- 0 where the target is 0,
- 2 where the prediction matches the target,
- 1 where the prediction does not match the target.
"""
correctness_mask = np.zeros_like(ground_truth_masks, dtype=int)
matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0)
correctness_mask[matches] = 2
mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0)
correctness_mask[mismatches] = 1
return correctness_mask
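
    # Worked example (hypothetical): targets [0, 5, 5] and predictions [9, 5, 2]
    # yield [0, 2, 1] (background stays 0, correct pixels 2, wrong pixels 1).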
@staticmethod
def create_positive_negative_image(generated_images, target_images):
"""
Create an image that shows the positive and negative classifications of the model.
Args:
generated_images (torch.Tensor): The images generated by the model.
target_images (torch.Tensor): The target images.
Returns:
list: A list of processed images.
"""
classification_masks = generated_images == target_images
processed_imgs = []
        for mask, target in zip(classification_masks, target_images):
            # color right classifications green, wrong classifications red, and the background black
            mask = mask.bool().cpu()  # boolean mask on CPU so it can index the CPU image tensor
            target = target.cpu()
            colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
            colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8)
            colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8)
            colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8)
            processed_imgs.append(colored_img)
return processed_imgs
@staticmethod
    def process_images(imgs, max_value=None):
        """
        Process a batch of images to be logged on wandb.
        Args:
            imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed.
            max_value (float, optional): The maximum value used to normalize the images. Defaults to None, in which case the maximum value in the batch is used.
        Returns:
            list: A list of (H, W, 3) uint8 arrays, with class 0 rendered black.
        """
        if max_value is None:
            max_value = np.max(imgs.cpu().numpy())
        normalized_img = imgs / max_value
        # tab20 + tab20b + tab20c provides 60 distinct colors; beyond that, fall back to a continuous colormap
        if max_value < 60:
            cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors)
        else:
            cmap = plt.get_cmap('viridis')
        processed_imgs = []
        for img in normalized_img.cpu().numpy():
            colored_img = cmap(img)
            colored_img[img == 0] = [0, 0, 0, 1]  # keep the background black
            colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8)
            processed_imgs.append(colored_img_uint8)
        return processed_imgs