# MultiMAE / utils/log_images.py
# Author: Bachmann Roman Christian (initial commit 3b49518)
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import wandb
import utils
from utils.datasets_semseg import (ade_classes, hypersim_classes,
nyu_v2_40_classes)
def inv_norm(tensor: torch.Tensor) -> torch.Tensor:
    """Undo the ImageNet mean/std normalization applied during pre-processing.

    Args:
        tensor: A normalized image tensor, as accepted by
            ``torchvision.transforms.Normalize``.

    Returns:
        The tensor mapped back to the un-normalized value range.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    # Normalize computes (x - m) / s. Choosing m = -mean / std and s = 1 / std
    # turns that into x * std + mean, i.e. the inverse of the forward transform.
    undo_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
        std=[1 / s for s in imagenet_std],
    )
    return undo_normalize(tensor)
@torch.no_grad()
def log_semseg_wandb(
    images: torch.Tensor,
    preds: List[np.ndarray],
    gts: List[np.ndarray],
    depth_gts: List[np.ndarray],
    dataset_name: str = 'ade20k',
    image_count=8,
    prefix=""
):
    """Log semantic-segmentation predictions and ground truths to wandb.

    Each image is logged as a ``wandb.Image`` with two mask overlays,
    "predictions" and "ground_truth", which share one class-label mapping.
    Wherever the ground truth equals ``utils.SEG_IGNORE_INDEX``, the logged
    prediction is set to that index too, so both masks show "ignore" there.

    Args:
        images: Batch of normalized images. Entries along the first dimension
            are passed through ``inv_norm`` before logging.
        preds: Per-image predicted label maps. They are NOT modified; the
            ignore masking is applied to a copy.
        gts: Per-image ground-truth label maps.
        depth_gts: Optional per-image depth maps. When empty (or ``None``), no
            depth images are logged.
        dataset_name: 'ade20k', 'hypersim' or 'nyu'; selects the class names.
        image_count: Maximum number of images to log.
        prefix: Key prefix; images are logged under ``f"{prefix}_{i}"`` and
            depth maps under ``f"{prefix}_{i}_depth"``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == 'ade20k':
        classes = ade_classes()
    elif dataset_name == 'hypersim':
        classes = hypersim_classes()
    elif dataset_name == 'nyu':
        classes = nyu_v2_40_classes()
    else:
        raise ValueError(f'Dataset {dataset_name} not supported for logging to wandb.')

    class_labels = {i: cls for i, cls in enumerate(classes)}
    # Index len(classes) is labelled "void" (presumably the dataset's
    # unlabeled class -- confirm against the dataset definitions).
    class_labels[len(classes)] = "void"
    class_labels[utils.SEG_IGNORE_INDEX] = "ignore"

    image_count = min(len(images), image_count)
    images = images[:image_count]
    preds = preds[:image_count]
    gts = gts[:image_count]
    # Accept None as well as an empty list to mean "no depth maps".
    has_depth = depth_gts is not None and len(depth_gts) > 0
    depth_gts = depth_gts[:image_count] if has_depth else None

    semseg_images = {}
    for i, (image, pred, gt) in enumerate(zip(images, preds, gts)):
        image = inv_norm(image)
        # Work on a copy so the caller's prediction arrays are left untouched.
        pred = np.array(pred, copy=True)
        pred[gt == utils.SEG_IGNORE_INDEX] = utils.SEG_IGNORE_INDEX
        semseg_images[f"{prefix}_{i}"] = wandb.Image(image, masks={
            "predictions": {
                "mask_data": pred,
                "class_labels": class_labels,
            },
            "ground_truth": {
                "mask_data": gt,
                "class_labels": class_labels,
            }
        })
        if depth_gts is not None:
            semseg_images[f"{prefix}_{i}_depth"] = wandb.Image(depth_gts[i])
    wandb.log(semseg_images, commit=False)
def _taskonomy_to_wandb_image(img: torch.Tensor, task: str, caption: str):
    """Convert one per-task image tensor into a captioned ``wandb.Image``.

    ``rgb`` images are un-normalized with ``inv_norm``. If the first dimension
    has size 1 it is squeezed away. If it has size 2, a zero-filled third
    plane is appended. Any other size is passed through unchanged.
    Assumes ``img`` is laid out as (C, H, W) -- TODO confirm with callers.
    """
    if task == 'rgb':
        img = inv_norm(img)
    if img.shape[0] == 1:
        img = img[0]
    elif img.shape[0] == 2:
        # F.pad lists pads from the last dim backwards: (W), (H), then the
        # third-to-last (channel) dim gets one trailing zero plane.
        img = F.pad(img, (0, 0, 0, 0, 0, 1), mode='constant', value=0.0)
    return wandb.Image(img, caption=caption)


@torch.no_grad()
def log_taskonomy_wandb(
    preds: Dict[str, torch.Tensor],
    gts: Dict[str, torch.Tensor],
    image_count=8,
    prefix=""
):
    """Log per-task predictions and ground truths as wandb image lists.

    For each task, up to ``image_count`` images are collected under the keys
    ``f'{prefix}_gt_{task}'`` (captions ``'GT #i'``) and
    ``f'{prefix}_pred_{task}'`` (captions ``'Pred #i'``). They are then logged
    with ``commit=False``. The ``'mask_valid'`` ground-truth entry is not
    logged.

    Args:
        preds: Mapping from task name to a batch of predicted images.
        gts: Mapping from task name to a batch of ground-truth images.
        image_count: Maximum number of images to log per task. It is also
            bounded by the shortest logged batch.
        prefix: Prefix for the wandb keys.
    """
    pred_tasks = list(preds.keys())
    gt_tasks = [task for task in gts.keys() if task != 'mask_valid']

    batches = [preds[task] for task in pred_tasks] + [gts[task] for task in gt_tasks]
    if not batches:
        # No tasks to log at all.
        return
    # Bound by every logged batch, not just the first prediction task, so
    # indexing below cannot run past the end of a shorter tensor.
    image_count = min(image_count, *(len(batch) for batch in batches))

    all_images = {}
    for i in range(image_count):
        for task in gt_tasks:
            all_images.setdefault(f'{prefix}_gt_{task}', []).append(
                _taskonomy_to_wandb_image(gts[task][i], task, f'GT #{i}'))
        for task in pred_tasks:
            all_images.setdefault(f'{prefix}_pred_{task}', []).append(
                _taskonomy_to_wandb_image(preds[task][i], task, f'Pred #{i}'))
    wandb.log(all_images, commit=False)