Spaces:
Runtime error
Runtime error
import evaluate | |
import json | |
import sys | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import ast | |
# from ece import ECE # loads local instead | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
""" | |
import seaborn as sns | |
sns.set_style('white') | |
sns.set_context("paper", font_scale=1) | |
""" | |
# High-resolution figures; figure size stays at matplotlib's default
# (a [10, 7] figsize was tried previously).
plt.rcParams.update({"figure.dpi": 300})
# Non-interactive Agg backend so figures can be rendered outside the main
# thread (Gradio calls handlers from worker threads); see
# https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
plt.switch_backend("agg")
# Metric controls shown next to the data table; their order matches the
# parameters of compute_and_plot after `data`.
sliders = [
    # At least one bin is required and bins are counted, so use integer steps
    # starting at 1 (a value of 0 would request zero bins).
    gr.Slider(1, 100, value=10, step=1, label="n_bins"),
    gr.Slider(
        0, 100, value=None, label="bin_range", visible=False
    ),  # DEV: need to have a double slider; hidden and currently unused
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]
# Initial value of every control, used when assembling example rows.
slider_defaults = [slider.value for slider in sliders]
# Example data: per-sample predicted class-probability vectors and the true
# class index.
df = dict()
df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]]
df["references"] = [0, 1, 2]

# Rows shown in the input table. Prediction vectors are stored as list
# literals in a string column (compute_and_plot parses them with
# ast.literal_eval); references are numbers.
example_rows = [
    ["[0.6, 0.2, 0.2]", 0],
    ["[0.7, 0.1, 0.2]", 2],
    ["[0, 0.95, 0.05]", 1],
]
# The initial value is passed to the constructor so the component
# post-processes it into its display format; assigning `.value` afterwards
# would bypass that step.
component = gr.Dataframe(
    headers=["predictions", "references"],
    col_count=2,
    datatype=["str", "number"],
    type="pandas",
    value=example_rows,
)
# One example = the table's value followed by every control's default.
sample_data = [[example_rows] + slider_defaults]
# Directory containing this app file; README.md is read from here for the
# interface's article section. Resolved from __file__ so it does not depend
# on how the process was launched (sys.path[0] varies between launchers).
local_path = Path(__file__).resolve().parent
metric = evaluate.load("jordyvl/ece")
# Alternative: evaluate.utils.launch_gradio_widget(metric) gives the stock
# widget; this app wires its own inputs and compute_fn instead.
def default_plot():
    """Create an empty reliability-diagram figure.

    The figure has two stacked axes, both with an x-range of [-0.05, 1.05]:
    the top axis (two thirds of the height) is titled "Reliability Diagram"
    and carries a dotted identity line labelled "Perfect"; the bottom axis is
    labelled "Confidence"/"Count" for per-bin frequencies.

    Returns:
        (fig, ax1, ax2): the figure, the top axis and the bottom axis.
    """
    fig = plt.figure()
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    ax2 = plt.subplot2grid((3, 1), (2, 0))

    # Identity line y = x: points on it are perfectly calibrated.
    ranged = np.linspace(0, 1, 10)
    ax1.plot(
        ranged,
        ranged,
        color="darkgreen",
        ls="dotted",
        label="Perfect",
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([0, 1.05])  # respective to bin range
    ax1.set_title("Reliability Diagram")
    ax1.set_xlim([-0.05, 1.05])  # respective to bin range

    # Bin frequencies. No legend is created here: ax2 has no labelled
    # artists yet, so legend() would only emit matplotlib's
    # "No artists with labels" warning. Callers add it after plotting markers.
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.set_xlim([-0.05, 1.05])  # respective to bin range
    return fig, ax1, ax2
def reliability_plot(results):
    """Draw a reliability diagram and a per-bin count histogram.

    Args:
        results: detailed output of the ECE metric (computed with
            ``detail=True``). Keys read here:
            ``"y_bar"`` -- one value per bin, used as that bin's upper edge
            (assumed sorted ascending; the first bin starts at 0);
            ``"p_bar"`` -- bar heights for the top axis; every entry except
            the last is drawn, NaN entries are drawn with height 0;
            ``"bin_freq"`` -- counts for the bins whose ``p_bar`` is not NaN;
            ``"accuracy"`` and ``"p_bar_cont"`` -- scalars drawn as vertical
            lines on the count axis.

    Returns:
        The matplotlib figure.
    """
    # DEV: might still need to write tests in case of equal mass binning
    # DEV: nicer would be to plot like a polygon
    # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py

    def _calibration_color(perfect, empirical):
        # Colour one bar by comparing its height with the bin's y_bar value.
        # NOTE(review): heights below y_bar are labelled "Underconfident";
        # whether that reading is right depends on whether p_bar holds per-bin
        # accuracy or confidence -- confirm against the metric.
        if np.allclose(perfect, empirical):
            return "limegreen"
        return "dodgerblue" if empirical < perfect else "orangered"

    fig, ax1, ax2 = default_plot()

    y_bar = np.asarray(results["y_bar"], dtype=float)
    p_bar = np.asarray(results["p_bar"], dtype=float)[:-1]

    # Bin i spans [bins_with_left_edge[i], bins_with_left_edge[i + 1]], i.e. it
    # ends at y_bar[i]. Bars are placed explicitly instead of via
    # hist(y_bar, bins=bins_with_left_edge): numpy bins are half-open [a, b),
    # so each interior y_bar value would fall into the bin to its right,
    # shifting every bar by one and stacking the last two in the final bin.
    bins_with_left_edge = np.insert(y_bar, 0, 0, axis=0)
    lefts = bins_with_left_edge[:-1]
    widths = np.diff(bins_with_left_edge)

    # Bin differences
    heights = np.nan_to_num(p_bar, copy=True, nan=0)
    colors = [_calibration_color(upper, height) for upper, height in zip(y_bar, heights)]
    ax1.bar(lefts, heights, width=widths, align="edge", color=colors)
    ax1handles = [
        mpatches.Patch(color="orangered", label="Overconfident"),
        mpatches.Patch(color="limegreen", label="Perfect", linestyle="dotted"),
        mpatches.Patch(color="dodgerblue", label="Underconfident"),
    ]

    # Bin frequencies: bin_freq only covers bins with a non-NaN p_bar, so
    # scatter it into a zero-filled array with one slot per bin.
    present = np.where(~np.isnan(p_bar))[0]
    bin_freqs = np.zeros(len(y_bar))
    bin_freqs[present] = results["bin_freq"]
    ax2.bar(lefts, bin_freqs, width=widths, align="edge", color="midnightblue")

    acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )
    ax1.legend(loc="lower right", handles=ax1handles)
    ax2.legend(handles=[acc_plt, conf_plt])
    ax1.set_xticks(bins_with_left_edge)
    ax2.set_xticks(bins_with_left_edge)
    plt.tight_layout()
    return fig
def _parse_prediction(prediction):
    """Return one table cell's probability vector as a list."""
    if isinstance(prediction, str):
        # literal_eval accepts only Python literals, so text typed into the
        # table cannot execute code.
        return ast.literal_eval(prediction)
    # Lists, tuples and numpy arrays are taken as-is (as a list).
    return list(prediction)


def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Compute the ECE of the table's predictions and draw its reliability diagram.

    Args:
        data: table with a ``predictions`` column (probability vectors, as
            sequences or as their string literal such as ``"[0.6, 0.2, 0.2]"``)
            and a ``references`` column (true class indices). A
            ``pandas.DataFrame`` is expected; a list of rows or a dict of
            columns is converted to one.
        n_bins: number of bins; slider values may arrive as floats and are
            converted to ``int``.
        bin_range: currently unused (its control is hidden).
        scheme, proxy, p: forwarded unchanged to the metric.

    Returns:
        (ECE value, matplotlib figure).
    """
    # DEV: check on invalid datatypes with better warnings
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data, columns=["predictions", "references"])
    # Skip incomplete rows; dropna() returns a copy, so the caller's frame is
    # left untouched.
    data = data.dropna()

    predictions = [_parse_prediction(prediction) for prediction in data["predictions"]]
    # Numeric table columns come back as floats (e.g. 1.0); references are
    # class indices.
    references = [int(reference) for reference in data["references"]]

    results = metric._compute(
        predictions,
        references,
        n_bins=int(n_bins),
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    print(results)
    plot = reliability_plot(results)
    return results["ECE"], plot
# gr.outputs.* was deprecated in Gradio 3 and removed in Gradio 4; components
# are passed directly as outputs (gr.Textbox works in both versions).
outputs = [gr.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
# outputs[1].value = default_plot().__dict__ #Does not work; yet needs to be JSON encoded
iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=evaluate.utils.parse_readme(local_path / "README.md"),
    title=f"Metric: {metric.name}",
    # examples=sample_data  # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
)
# Keep `iface` bound to the Interface itself, not to launch()'s return value.
iface.launch()