File size: 5,452 Bytes
9a960ac
 
 
 
 
 
df37aa3
 
9a960ac
 
 
 
 
df37aa3
 
 
9a960ac
 
 
 
 
 
 
 
 
 
 
301b1c6
9a960ac
 
 
 
 
 
 
df37aa3
9a960ac
 
 
 
 
 
df37aa3
9a960ac
 
 
 
 
 
ca1cdc2
 
 
df37aa3
9a960ac
5b27515
df37aa3
9a960ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df37aa3
 
 
 
 
 
 
 
9a960ac
 
 
df37aa3
9a960ac
 
 
 
 
 
 
 
 
 
 
df37aa3
 
 
9a960ac
df37aa3
 
 
 
 
 
9a960ac
 
 
 
 
 
 
df37aa3
9a960ac
df37aa3
9a960ac
 
 
 
 
 
 
 
 
 
 
 
 
 
df37aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a960ac
df37aa3
9a960ac
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Gradio demo of image classification with OOD detection.

If the image example is probably OOD, the model will abstain from the prediction.
"""
import json
import logging
import pickle
from glob import glob

import gradio as gr
import numpy as np
import timm
import torch
import torch.nn.functional as F
from gradio.components import JSON, Image, Label
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

_logger = logging.getLogger(__name__)

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of most likely classes returned by `predict` and shown in the Label output.
TOPK = 3

# load model
print("Loading model...")
model = timm.create_model("resnet50.tv2_in1k", pretrained=True)
model.to(device)
model.eval()

# Dataset labels: JSON object keyed by class index (stored as strings, converted to int).
# The context manager closes the file instead of leaking the handle.
with open("ilsvrc2012.json", encoding="utf-8") as labels_file:
    idx2label = {int(k): v for k, v in json.load(labels_file).items()}
# Log instead of printing the full 1000-entry mapping to stdout on every start-up.
_logger.debug("Loaded %d class labels: %s", len(idx2label), idx2label)

# Evaluation-time preprocessing derived from the model's pretrained data config.
config = resolve_data_config({}, model=model)
config["is_training"] = False
transform = create_transform(**config)

# Feature extractor returning both the penultimate (pooled) features and the logits.
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]

feature_extractor = create_feature_extractor(model, features_names)


# Class centroids (in logit space) used by the Igeood detector. `map_location` loads
# them onto the model's device, so they can be combined with the model's logits even
# when the file was saved from another device (e.g. CUDA tensors on a CPU-only host).
centroids = torch.load("centroids_resnet50.tv2_in1k_igeood_logits.pt", map_location=device)
# OOD detector thresholds: an input is flagged as OOD when a detector's score is
# below that detector's threshold.
msp_threshold = 0.3796
energy_threshold = 8
igeood_threshold = 2.4984


def mahalanobis_penult(features):
    """Return the negated L2 norm of a single penultimate-layer feature vector.

    Despite the name, no Mahalanobis distance is computed: the score is simply
    ``-||features||_2`` taken along dim 1.

    Args:
        features: 2-D tensor holding one feature vector per row; ``.item()``
            requires exactly one row.

    Returns:
        The negated norm as a Python float.
    """
    # torch.norm spells this option ``keepdim``; ``keepdims`` raises a TypeError.
    scores = torch.norm(features, dim=1, keepdim=True)
    # The norm has shape (batch, 1); the min over that size-1 axis only drops it.
    s = torch.min(scores, dim=1)[0]
    return -s.item()


def msp(logits):
    """Maximum softmax probability (MSP) of a single row of logits.

    Softmax is taken over dim 1 and the largest class probability is returned
    as a Python float (``.item()`` requires a single-row batch).
    """
    class_probs = torch.softmax(logits, dim=1)
    top_prob, _ = class_probs.max(dim=-1)
    return top_prob.item()


def energy(logits):
    """Energy-based score: log-sum-exp of the logits over dim 1.

    This equals the negative free energy at temperature 1. It is returned as a
    Python float (``.item()`` requires a single-row batch).
    """
    log_partition = logits.logsumexp(dim=1)
    return log_partition.item()


def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
    """Igeood logit score: mean Fisher-Rao distance from the input to the centroids.

    ``logits`` and every centroid row are mapped to temperature-scaled softmax
    distributions. The Fisher-Rao distance between categorical distributions p and
    q is ``2 * acos(sum_i sqrt(p_i * q_i))``.

    Args:
        logits: 2-D tensor of logits with a single row (``.item()`` requires it).
        temperature: Softmax temperature applied to both logits and centroids.
        centroids: 2-D tensor with one centroid (in logit space) per row and the
            same number of columns as ``logits``. It is moved to the device and
            dtype of ``logits`` before use.
        epsilon: Margin keeping the acos argument inside [-1 + eps, 1 - eps]. In
            float32 this margin rounds away, but the clamp still caps the argument
            at +/-1, so acos never receives an out-of-range value.

    Returns:
        The distance averaged over all centroids, as a Python float.
    """
    # Align device/dtype so the matmul works even if the centroids live on another
    # device (e.g. CPU centroids with CUDA logits) or were saved in another dtype.
    centroids = centroids.to(device=logits.device, dtype=logits.dtype)
    sqrt_p = torch.sqrt(F.softmax(logits / temperature, dim=1))
    sqrt_q = torch.sqrt(F.softmax(centroids / temperature, dim=1))
    # Bhattacharyya coefficient between the input and each centroid: (1, n_centroids).
    bhattacharyya = sqrt_p @ sqrt_q.T
    distances = 2 * torch.acos(torch.clamp(bhattacharyya, -1 + epsilon, 1 - epsilon))
    return distances.mean(dim=1).item()


def predict(image):
    """Classify ``image`` and score it with three OOD detectors.

    Args:
        image: Input image accepted by ``transform`` (a PIL image when called from
            the Gradio ``Image(type="pil")`` input).

    Returns:
        A ``(result, ood_scores)`` tuple. ``result`` maps the label of each of the
        ``TOPK`` most likely classes to its softmax probability. ``ood_scores`` holds
        each detector's score (rounded to 4 decimals for display) and a boolean
        telling whether that detector flags the input as OOD.
    """
    # Forward pass on a batch of one image.
    inputs = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    logits = features[logits_key]

    # Top-TOPK predictions. Index the single batch row instead of squeezing, so
    # TOPK == 1 still gives 1-d tensors that can be iterated.
    probabilities = torch.softmax(logits, dim=-1)
    top_probs, top_idxs = torch.topk(probabilities, TOPK)
    _logger.info(top_probs)
    _logger.info(top_idxs)
    result = {idx2label[i.item()]: p.item() for i, p in zip(top_idxs[0], top_probs[0])}

    # OOD: a detector flags the input when its score is below its threshold. The
    # decision uses the raw score; rounding is only for display, so it cannot flip
    # a decision for scores right at a threshold.
    msp_score = msp(logits)
    energy_score = energy(logits)
    igeood_score = igeoodlogits_vec(logits, 1, centroids)
    ood_scores = {
        "MSP": round(msp_score, 4),
        "MSP, is the input OOD?": msp_score < msp_threshold,
        "Energy": round(energy_score, 4),
        "Energy, is the input OOD?": energy_score < energy_threshold,
        "Igeood": round(igeood_score, 4),
        "Igeood, is the input OOD?": igeood_score < igeood_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores


def main():
    """Build the Gradio interface of the OOD detection demo and serve it on port 7860."""
    # Example images: ImageNet images first, then OOD images. `glob` returns files in
    # arbitrary, filesystem-dependent order, so each group is sorted to keep the
    # examples bar stable across machines and runs.
    examples = sorted(glob("images/imagenet/*")) + sorted(glob("images/ood/*"))

    # gradio interface
    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        examples=examples,
        # All examples on one page; avoid a page size of 0 when no example is found.
        examples_per_page=max(len(examples), 1),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description=(
            "Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. "
            "The objective of an OOD detector is to determine whether the input sample comes from the distribution known by the AI model. "
            "For instance, an input that does not belong to any of the known classes or is from a different domain should be flagged by the detector.\n"
            "In this demo we will display the decision of three OOD detectors on a ResNet-50 model trained to classify on the ImageNet-1K dataset (top-1 accuracy 80%). "
            "This model can classify among 1000 classes from several categories, including `animals`, `vehicles`, `clothing`, `instruments`, `plants`, etc. "
            "For the complete hierarchy of classes, please check the website https://observablehq.com/@mbostock/imagenet-hierarchy. "
            "\n\n"
            "## Instructions:\n"
            "1. Upload an image of your choice or select one from the examples bar.\n"
            f"2. The model will predict the top {TOPK} most likely classes for the image.\n"
            "3. The OOD detectors will output their scores and decision on the image. The smaller the score, the less confident the detector is on the sample being in-distribution.\n"
            "4. If the image is OOD, the model will abstain from the prediction and flag it to the practitioner.\n"
            "\n\n\nEnjoy the demo!"
        ),
        cache_examples=True,
    )
    interface.launch(server_port=7860)
    # Release the server once launch() returns.
    interface.close()


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARN)

    gr.close_all()
    main()