"""
Gradio demo of image classification with OOD detection.
If the image example is probably OOD, the model will abstain from the prediction.
"""
import json
import logging
from glob import glob
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn.functional as F
from gradio.components import JSON, Image, Label
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.models.feature_extraction import create_feature_extractor
_logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
TOPK = 3
# load model
print("Loading model...")
model = timm.create_model("resnet50.tv2_in1k", pretrained=True)
model.to(device)
model.eval()
# dataset labels (ImageNet-1K class index -> human-readable name)
with open("ilsvrc2012.json") as f:
    idx2label = {int(k): v for k, v in json.load(f).items()}
# transformation
config = resolve_data_config({}, model=model)
config["is_training"] = False
transform = create_transform(**config)
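# The resulting `transform` resizes, center-crops and normalizes a PIL image according to
# the model's pretrained data config, producing the input tensor the network expects.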
# create a feature extractor that returns the pooled penultimate features
# and the classifier logits in a single forward pass
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]
feature_extractor = create_feature_extractor(model, features_names)
# load class centroids used by the Igeood detector
centroids = torch.load("centroids_resnet50.tv2_in1k_igeood_logits.pt", map_location=device)
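# Assumption: the centroids are the per-class mean logits of this checkpoint, computed
# offline on in-distribution (ImageNet-1K) data; they are only consumed by the Igeood score.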
# OOD detector thresholds
msp_threshold = 0.3796
energy_threshold = 8
igeood_threshold = 2.4984
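# Decision rule used in `predict` below: an input is flagged as OOD when its score falls
# below the corresponding threshold. The thresholds are assumed to have been calibrated
# offline on in-distribution data.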
def mahalanobis_penult(features):
    # Distance-based score on the penultimate features (not used in the demo below):
    # returns the negative feature norm, so smaller values indicate more unusual inputs.
    scores = torch.norm(features, dim=1, keepdim=True)
    s = torch.min(scores, dim=1)[0]
    return -s.item()
def msp(logits):
    # Maximum softmax probability (MSP) confidence score
    return torch.softmax(logits, dim=1).max(-1)[0].item()
def energy(logits):
    # Energy score: logsumexp of the logits (higher means more in-distribution)
    return torch.logsumexp(logits, dim=1).item()
def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
    # Igeood score on the logits: distance between the temperature-scaled softmax of the
    # logits and the class centroid distributions, averaged over centroids.
    logits = torch.sqrt(F.softmax(logits / temperature, dim=1))
    centroids = torch.sqrt(F.softmax(centroids / temperature, dim=1))
    mult = logits @ centroids.T
    stack = 2 * torch.acos(torch.clamp(mult, -1 + epsilon, 1 - epsilon))
    return stack.mean(dim=1).item()
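# Note: for discrete distributions p and q, 2 * arccos(sum_i sqrt(p_i * q_i)) is the
# Fisher-Rao (geodesic) distance, so the score above is the average geodesic distance
# between the predicted distribution and the class centroids.
# Hedged sanity check (not part of the demo): if the single centroid equals the logits,
# the distance is numerically close to zero:
#   z = torch.randn(1, 1000)
#   assert igeoodlogits_vec(z, 1.0, z) < 1e-2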
def predict(image):
    # forward pass
    inputs = transform(image).unsqueeze(0)
    inputs = inputs.to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    # top-k predictions
    probabilities = torch.softmax(features[logits_key], dim=-1)
    softmax, class_idxs = torch.topk(probabilities, TOPK)
    _logger.info(softmax)
    _logger.info(class_idxs)
    result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs.squeeze(), softmax.squeeze())}
    # OOD scores and decisions
    msp_score = round(msp(features[logits_key]), 4)
    energy_score = round(energy(features[logits_key]), 4)
    igeood_score = round(igeoodlogits_vec(features[logits_key], 1, centroids), 4)
    ood_scores = {
        "MSP": msp_score,
        "MSP, is the input OOD?": msp_score < msp_threshold,
        "Energy": energy_score,
        "Energy, is the input OOD?": energy_score < energy_threshold,
        "Igeood": igeood_score,
        "Igeood, is the input OOD?": igeood_score < igeood_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores
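# Example usage outside the Gradio app (hedged sketch; assumes Pillow is installed and
# that an example image exists at the given path):
#   from PIL import Image as PILImage
#   preds, scores = predict(PILImage.open("images/ood/<some_image>.jpg"))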
def main():
    # image examples for the demo (shuffling currently disabled)
    examples = glob("images/imagenet/*") + glob("images/ood/*")
    np.random.seed(42)
    # np.random.shuffle(examples)
    # gradio interface
    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        examples=examples,
        examples_per_page=len(examples),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description=(
            "Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. "
            "The objective of an OOD detector is to determine whether an input sample comes from the distribution the model was trained on. "
            "For instance, an input that does not belong to any of the known classes or comes from a different domain should be flagged by the detector.\n"
            "In this demo we display the decisions of three OOD detectors on a ResNet-50 model trained to classify the ImageNet-1K dataset (top-1 accuracy 80%). "
            "This model can classify among 1000 classes from several categories, including `animals`, `vehicles`, `clothing`, `instruments`, `plants`, etc. "
            "For the complete hierarchy of classes, please check https://observablehq.com/@mbostock/imagenet-hierarchy. "
            "\n\n"
            "## Instructions:\n"
            "1. Upload an image of your choice or select one from the examples bar.\n"
            "2. The model will predict the top 3 most likely classes for the image.\n"
            "3. The OOD detectors will output their scores and decisions for the image. The smaller the score, the less confident the detector is that the sample is in-distribution.\n"
            "4. If the image is OOD, the model will abstain from the prediction and flag it to the practitioner.\n"
            "\n\n\nEnjoy the demo!"
        ),
        cache_examples=True,
    )
    interface.launch(server_port=7860)
    interface.close()
if __name__ == "__main__":
    logging.basicConfig(level=logging.WARN)
    gr.close_all()
    main()