# ece / tests.py -- last commit 3257c6c ("update app") by jordyvl; 1.16 kB
import numpy as np
# Hand-written fixtures. Each entry pairs a list of predictions with a list of
# references and the "result" dictionary associated with that pair. The rows
# below are (predictions, references, metric_score).
test_cases = [
    {
        "predictions": preds,
        "references": refs,
        "result": {"metric_score": score},
    }
    for preds, refs, score in (
        ([0, 0], [1, 1], 0),
        ([1, 1], [1, 1], 1),
        ([1, 0], [1, 1], 0.5),
    )
]
def test_ECE():
    """Smoke-test ``ECE()._compute`` on randomly generated multi-class data.

    Draws ``N`` random ``(reference, prediction)`` pairs for a ``K``-class
    problem, stacks them into arrays and prints the ``"ECE"`` entry of the
    returned result. Nothing is asserted, so the test only fails if the
    computation raises.

    NOTE(review): ``ECE`` is not imported in the visible part of this file;
    it is assumed to be made available by the surrounding module -- confirm.
    """
    N = 10  # N evaluation instances {(x_i,y_i)}_{i=1}^N
    K = 5  # K class problem

    def random_mc_instance(concentration=1, onehot=False):
        """Return one random ``(reference, prediction)`` pair.

        Args:
            concentration: Value used for each of the ``K`` Dirichlet
                concentration parameters.
            onehot: If true, return the reference as a length-``K`` one-hot
                row instead of an integer class index.

        Returns:
            Tuple ``(reference, prediction)`` where ``prediction`` is a
            Dirichlet sample of length ``K``.
        """
        alpha = [concentration] * K
        # The target class is the arg-max of an independent Dirichlet draw.
        reference = np.argmax(np.random.dirichlet(alpha), -1)  # class targets
        prediction = np.random.dirichlet(alpha)  # probabilities
        if onehot:
            # ``reference`` is already a scalar class index. Applying argmax
            # to it a second time cannot recover the label, so index the
            # identity matrix with the label directly.
            reference = np.eye(K)[reference]
        return reference, prediction

    references, predictions = zip(*(random_mc_instance() for _ in range(N)))
    references = np.array(references, dtype=np.int64)
    predictions = np.array(predictions, dtype=np.float32)
    res = ECE()._compute(predictions, references)
    print(f"ECE: {res['ECE']}")