ironjr's picture
untroubled files first
24f9881
raw
history blame
6.48 kB
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat
import argparse
from pprint import pprint
import torch
from zoedepth.utils.easydict import EasyDict as edict
from tqdm import tqdm
from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,
count_parameters)
@torch.no_grad()
def infer(model, images, **kwargs):
    """Inference with horizontal-flip test-time augmentation.

    The model is run on ``images`` and on their horizontally flipped copy.
    The second prediction is flipped back so both are spatially aligned,
    and the element-wise mean is returned.

    Args:
        model: Callable invoked as ``model(images, **kwargs)``. Its output may
            be a tensor, a list/tuple (the last element is used), or a dict
            (``'metric_depth'`` if present, otherwise ``'out'``).
        images: Image batch, e.g. ``N, C, H, W``. It is flipped along its last
            dimension (the width for that layout).
        **kwargs: Forwarded unchanged to both ``model`` calls.

    Returns:
        torch.Tensor: Mean of the prediction and the un-flipped prediction of
        the flipped input.

    Raises:
        NotImplementedError: If the model output type is not recognised.
    """
    def get_depth_from_prediction(pred):
        # Normalise the different model output containers to a single tensor.
        if isinstance(pred, torch.Tensor):
            return pred
        if isinstance(pred, (list, tuple)):
            return pred[-1]
        if isinstance(pred, dict):
            return pred['metric_depth'] if 'metric_depth' in pred else pred['out']
        raise NotImplementedError(f"Unknown output type {type(pred)}")

    pred1 = get_depth_from_prediction(model(images, **kwargs))
    # Flip along the last (width) axis instead of a hard-coded dim 3: this is
    # identical for 4-D tensors, and also works for predictions without a
    # channel axis (N, H, W), where dim 3 does not exist.
    pred2 = get_depth_from_prediction(model(torch.flip(images, [-1]), **kwargs))
    pred2 = torch.flip(pred2, [-1])
    return 0.5 * (pred1 + pred2)
def _save_visualization(out_dir, index, image, depth, pred):
    """Write the input image and colorized ground truth / prediction as PNGs.

    Files are written to ``out_dir`` (created if missing) as
    ``{index}_img.png``, ``{index}_depth.png`` and ``{index}_pred.png``.
    Both depth maps are colorized with a fixed 0..10 value range.
    """
    # Imported lazily so these dependencies are only needed when saving.
    import os
    from PIL import Image
    import torchvision.transforms as transforms
    from zoedepth.utils.misc import colorize

    os.makedirs(out_dir, exist_ok=True)
    d = colorize(depth.squeeze().cpu().numpy(), 0, 10)
    p = colorize(pred.squeeze().cpu().numpy(), 0, 10)
    im = transforms.ToPILImage()(image.squeeze().cpu())
    im.save(os.path.join(out_dir, f"{index}_img.png"))
    Image.fromarray(d).save(os.path.join(out_dir, f"{index}_depth.png"))
    Image.fromarray(p).save(os.path.join(out_dir, f"{index}_pred.png"))


@torch.no_grad()
def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    """Evaluate ``model`` over ``test_loader`` and return averaged metrics.

    Each sample is moved to the GPU, predicted with flip augmentation via
    ``infer``, and scored with ``compute_metrics``. Samples carrying a falsy
    ``'has_valid_depth'`` entry are skipped.

    Args:
        model: Depth model; switched to eval mode here.
        test_loader: Iterable of sample dicts with at least ``'image'``,
            ``'depth'`` and ``'dataset'`` keys; ``len()`` is used for the
            progress bar.
        config: Passed to ``compute_metrics``. If it has a truthy
            ``save_images`` entry, visualizations are written to that directory.
        round_vals: Whether to round each averaged metric.
        round_precision: Number of decimal digits used when rounding.

    Returns:
        dict: Metric name -> averaged (optionally rounded) value.
    """
    model.eval()
    save_dir = config.save_images if "save_images" in config and config.save_images else None
    metrics = RunningAverageDict()
    for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        if 'has_valid_depth' in sample and not sample['has_valid_depth']:
            continue
        image = sample['image'].cuda()
        depth = sample['depth'].cuda()
        # NOTE(review): squeeze + two unsqueeze(0) assumes batch size 1 — confirm.
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)
        # This magic number (focal) is only used for evaluating BTS model.
        # Built lazily: dict.get() evaluates its default eagerly and would
        # allocate a GPU tensor on every iteration even when 'focal' exists.
        focal = sample['focal'] if 'focal' in sample else torch.Tensor([715.0873]).cuda()
        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)

        if save_dir:
            _save_visualization(save_dir, i, image, depth, pred)

        metrics.update(compute_metrics(depth, pred, config=config))

    # NOTE(review): guard in case get_value() yields None when no sample was
    # scored (e.g. every sample skipped) — TODO confirm against RunningAverageDict.
    values = metrics.get_value() or {}
    if round_vals:
        return {k: round(v, round_precision) for k, v in values.items()}
    return dict(values)
def main(config):
    """Build the model and evaluation loader from ``config`` and evaluate.

    The metrics are printed in green before the parameter count is attached,
    so the printed dict does not contain ``'#params'`` but the returned one does.

    Args:
        config: Configuration used to build the model and the data loader.

    Returns:
        dict: Evaluation metrics plus a ``'#params'`` entry giving the total
        parameter count in millions (two decimals, ``'M'`` suffix).
    """
    net = build_model(config)
    loader = DepthDataLoader(config, 'online_eval').data
    net = net.cuda()
    results = evaluate(net, loader, config)

    # One print call emitting colour code, metrics and reset on separate lines.
    print(colors.fg.green, results, colors.reset, sep="\n")

    n_params_millions = round(count_parameters(net, include_all=True) / 1e6, 2)
    results['#params'] = f"{n_params_millions}M"
    return results
def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):
    """Build an eval config for ``model_name`` on ``dataset`` and evaluate it.

    Args:
        model_name: Model name passed to ``get_config``.
        pretrained_resource: Weights resource. When falsy it is not overridden,
            so the default from the model config applies.
        dataset: Dataset name passed to ``get_config``.
        **kwargs: Extra config overrides forwarded to ``get_config``.

    Returns:
        dict: Metrics returned by ``main``.
    """
    if pretrained_resource:
        overrides = {**kwargs, "pretrained_resource": pretrained_resource}
    else:
        # Keep the config's own default pretrained resource.
        overrides = kwargs
    config = get_config(model_name, "eval", dataset, **overrides)
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    return main(config)
def resolve_datasets(spec, groups=None):
    """Expand a ``--dataset`` argument into a list of dataset names.

    The group keywords are checked in order as substrings of ``spec``. The
    first match returns that group's datasets. Otherwise ``spec`` is treated
    as a comma-separated list (a single name is a one-element list). Names
    are whitespace-stripped and empty entries, e.g. from a trailing comma,
    are dropped.

    Args:
        spec: Raw value of the ``--dataset`` command-line option.
        groups: Optional sequence of ``(keyword, datasets)`` pairs checked in
            order. Defaults to ALL_INDOOR, ALL_OUTDOOR, then ALL
            (ALL_EVAL_DATASETS); the more specific keywords come first because
            "ALL" is a substring of both.

    Returns:
        list: Dataset names to evaluate, in order (possibly empty).
    """
    if groups is None:
        groups = (
            ("ALL_INDOOR", ALL_INDOOR),
            ("ALL_OUTDOOR", ALL_OUTDOOR),
            ("ALL", ALL_EVAL_DATASETS),
        )
    for keyword, names in groups:
        if keyword in spec:
            return list(names)
    # Strip so "nyu, kitti" works; drop empties so "nyu," doesn't evaluate "".
    return [name.strip() for name in spec.split(",") if name.strip()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str,
                        required=True, help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str,
                        required=False, default=None, help="Pretrained resource to use for fetching weights. If not set, default resource from model config is used, Refer models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")
    args, unknown_args = parser.parse_known_args()
    # Arguments argparse does not recognise are parsed by parse_unknown and
    # forwarded to eval_model as config overrides.
    overwrite_kwargs = parse_unknown(unknown_args)

    datasets = resolve_datasets(args.dataset)
    if not datasets:
        parser.error("--dataset must name at least one dataset")

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)