Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" | |
Author: Philipp Seidl | |
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning | |
Johannes Kepler University Linz | |
Contact: [email protected] | |
Plot utils | |
""" | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from matplotlib import pyplot as plt | |
plt.style.use('default') | |
def normal_approx_interval(p_hat, n, z=1.96):
    r"""Return the half-width of the normal-approximation interval of a proportion.

    Approximates the distribution of error about a binomially distributed
    observation, :math:`\hat{p}`, with a normal distribution and returns
    ``z * sqrt(p_hat * (1 - p_hat) / n)``.

    Args:
        p_hat: Observed proportion; a scalar or an array-like. Values outside
            ``[0, 1]`` are clipped into that range so the radicand can never
            become negative (a plain float such as ``1.0004`` would otherwise
            yield a complex result).
        n: Number of observations behind ``p_hat``; scalar or array-like that
            broadcasts against ``p_hat``.
        z: Normal quantile scaling the interval.
            ``z = 1.96`` --> alpha = 0.05, ``z = 1`` --> one standard error.

    Returns:
        The half-width as a NumPy float (or array for array input). NaN inputs
        propagate as NaN.

    See https://www.wikiwand.com/en/Binomial_proportion_confidence_interval
    """
    p_hat = np.clip(p_hat, 0.0, 1.0)
    return z * np.sqrt(p_hat * (1 - p_hat) / n)
# Shared colour palette: name -> (r, g, b) with every channel scaled from
# 0..255 down to 0..1, the form matplotlib accepts as a colour.
# NOTE(review): "lightblue" and "blue" hold the same RGB value -- confirm
# whether "lightblue" was meant to be a different shade.
our_colors = {
    name: (red / 255, green / 255, blue / 255)
    for name, (red, green, blue) in (
        ("lightblue", (0, 132, 187)),
        ("red", (217, 92, 76)),
        ("blue", (0, 132, 187)),
        ("green", (91, 167, 85)),
        ("yellow", (241, 188, 63)),
        ("cyan", (79, 176, 191)),
        ("grey", (125, 130, 140)),
        ("lightgreen", (191, 206, 82)),
        ("violett", (174, 97, 157)),
    )
}
def plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
    """Shade the normal-approximation interval around a curve of proportions.

    Draws a band from ``p_hat - w`` to ``p_hat + w`` on the current matplotlib
    axes, where ``w`` is ``normal_approx_interval(p_hat, n, z)`` per point.

    Args:
        p_hats: Sequence of observed proportions. Missing entries given as
            ``None`` or NaN are tolerated and leave a gap in the band.
        n_samples: Sequence with the number of observations behind each entry
            of ``p_hats``; entries beyond ``len(p_hats)`` are ignored.
        z: Normal quantile passed on to ``normal_approx_interval``.
        color: Fill colour of the band.
        alpha: Opacity of the band.
        xs: x positions of the points; defaults to ``0 .. len(p_hats) - 1``.
    """
    # dtype=float turns None into NaN, so a missing accuracy no longer raises
    # a TypeError in the interval arithmetic below.
    p_hats = np.array(p_hats, dtype=float)
    n_obs = np.array(n_samples, dtype=float)[:len(p_hats)]
    half_widths = normal_approx_interval(p_hats, n_obs, z=z)
    xs = range(len(p_hats)) if xs is None else xs
    plt.fill_between(xs, p_hats - half_widths, p_hats + half_widths, color=color, alpha=alpha)
def plot_loss(hist):
    """Draw the training and validation loss curves on the current axes.

    Args:
        hist: Mapping providing ``'step'`` / ``'loss'`` for the training curve
            and ``'steps_valid'`` / ``'loss_valid'`` for the validation curve.
    """
    # Training curve first, validation second: the legend labels below are
    # matched to the lines in the order they were drawn.
    train_steps, train_loss = hist['step'], hist['loss']
    plt.plot(train_steps, train_loss)

    valid_steps = hist['steps_valid']
    valid_loss = np.array(hist['loss_valid'])
    plt.plot(valid_steps, valid_loss)

    plt.xlabel('update-step')
    plt.ylabel('loss (categorical-crossentropy-loss)')
    plt.legend(['train', 'validation'])
def plot_topk(hist, sets=['train', 'valid', 'test'], with_last = 2):
    """Plot top-k accuracy against k for several data splits.

    A hard-coded baseline curve is drawn first as a black dashed line (it is
    not listed in the legend), followed by one curve per split and checkpoint.

    Args:
        hist: Mapping with entries ``'t{k}_acc_{split}'`` for every k in
            ``[1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]``; each entry is a
            sequence indexed from the end (``[-1]`` is the latest checkpoint).
        sets: Names of the splits to draw. The default list is only read,
            never mutated.
        with_last: Exclusive upper bound on the checkpoints drawn, counted
            from the end: checkpoints ``-1 .. -(with_last - 1)`` are plotted,
            so the default ``2`` draws only the latest one. Older checkpoints
            are drawn more transparently (``alpha = 1 / i``).
    """
    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
    plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')

    # Keep the handles of the latest-checkpoint curves. Passing bare labels to
    # plt.legend would assign them to artists in creation order, i.e. the
    # baseline line would be labelled sets[0] and every entry shifted by one.
    legend_handles = []
    for i in range(1, with_last):
        for s in sets:
            (line,) = plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks], '.--', alpha=1/i)
            if i == 1:
                legend_handles.append(line)

    plt.xlabel('top-k')
    plt.ylabel('Accuracy')
    if legend_handles:
        plt.legend(legend_handles, sets)
    plt.title('Hopfield-NN')
    plt.ylim([-0.02, 1])
def _nte_reference_curves(dataset, group_by_template_fp):
    """Return the hard-coded ``(basel_values, pretr_values, staticQK)`` curves for *dataset*."""
    if dataset == 'Sm':
        basel_values = [0.        , 0.38424785, 0.66807858, 0.7916149 , 0.9051132 ,
                        0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
                        0.97215858, 0.9896713 , 0.99917817]
        pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
                        0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
                        0.95907246, 0.9839079 , 0.98683919]
        staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142,
                    0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
        if group_by_template_fp:
            staticQK = [0.2651, 0.2617, 0.261 , 0.2181, 0.2622, 0.2393, 0.2157,
                        0.2184, 0.2   , 0.225 , 0.2039, 0.4568, 0.5293]
        return basel_values, pretr_values, staticQK
    if dataset == 'Lg':
        pretr_values = [0.03410448, 0.65397054, 0.7254572 , 0.78969294, 0.81329924,
                        0.8651173 , 0.86775655, 0.8593128 , 0.88184124, 0.87764794,
                        0.89734215, 0.93328846, 0.99531597]
        basel_values = [0.        , 0.62478044, 0.68784314, 0.75089511, 0.77044644,
                        0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
                        0.8662873 , 0.92308414, 1.00042408]
        staticQK = [0.006416, 0.00686, 0.00616, 0.00825, 0.005085, 0.006718, 0.01041,
                    0.0015335, 0.006668, 0.004673, 0.001706, 0.02551, 0.04074]
        return basel_values, pretr_values, staticQK
    if dataset == 'Golden':
        return [0]*13, [0]*13, [0]*13
    raise ValueError(f"unknown dataset {dataset!r}; expected 'Sm', 'Lg' or 'Golden'")


def _nte_default_n_samples(dataset, group_by_template_fp):
    """Return the built-in per-bin test-sample counts for *dataset*, or None if there are none."""
    if dataset == 'Sm':
        if group_by_template_fp:
            return [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
        return [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
    if dataset == 'Lg':
        if group_by_template_fp:
            return [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
        return [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958]
    return None


def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
             draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False, schwaller_hist=None, fortunato_hist=None):
    """Plot top-100 test accuracy per number-of-training-examples bin.

    The x axis has 13 bins labelled ``'0' .. '10'``, ``'>10'`` and ``'>49'``.
    Four curves are drawn on the current axes in this order: the model read
    from ``hist``, the 'Fortunato et al.' curve, the 'FNN baseline' curve and
    the 'FPM baseline' curve. Optionally a bar chart of the per-bin sample
    proportion and shaded normal-approximation bands are added.

    Args:
        hist: Mapping with entries ``'t100_acc_nte_{bin}'`` for each bin label;
            bins that are missing are drawn as gaps.
        dataset: ``'Sm'``, ``'Lg'`` or ``'Golden'``; selects the hard-coded
            reference curves and default sample counts and is used in the
            title ``'USPTO-{dataset}'``.
        last_cpt: Which checkpoint of ``hist`` to read, counted from the end.
        include_bar: Draw the grey bars of per-bin sample proportions.
        model_legend: Legend label of the model curve.
        draw_std: Draw the shaded interval bands via ``plot_std``.
        z: Normal quantile for the bands (1.96 for a 95% CI).
        n_samples: Per-bin sample counts. When ``None`` the built-in counts of
            the dataset are used; an explicitly passed value always wins.
        group_by_template_fp: Select the alternative built-in FPM curve ('Sm'
            only) and the alternative built-in sample counts.
        schwaller_hist: Optional history replacing the 'FNN baseline' curve;
            read at the index of its minimal ``'loss_valid'``.
        fortunato_hist: Optional history replacing the 'Fortunato et al.'
            curve; read at the index of its minimal ``'loss_valid'``.

    Raises:
        ValueError: If ``dataset`` is not one of the known names.

    Note:
        ``'Golden'`` has no built-in sample counts; unless ``n_samples`` is
        given, the bars and the interval bands are skipped for it.
    """
    line_fmt = '.--'
    lw = 2
    ms = 8
    k = 100
    xti = [*[str(i) for i in range(11)], '>10', '>49']
    ntes = range(13)

    # Fail early, before anything is drawn, on a dataset without reference data.
    basel_values, pretr_values, staticQK = _nte_reference_curves(dataset, group_by_template_fp)

    if schwaller_hist:
        midx = np.argmin(schwaller_hist['loss_valid'])
        basel_values = [schwaller_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]
    if fortunato_hist:
        midx = np.argmin(fortunato_hist['loss_valid'])
        pretr_values = [fortunato_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]

    # Resolve the default counts independently of include_bar: the interval
    # bands need them as well.
    if n_samples is None:
        n_samples = _nte_default_n_samples(dataset, group_by_template_fp)

    if include_bar and n_samples is not None:
        # Normalised by the sum without the last bin -- presumably because
        # '>49' is a subset of '>10' and would be counted twice; TODO confirm.
        plt.bar(range(11+2), np.array(n_samples)/sum(n_samples[:-1]), alpha=0.4, color=our_colors['grey'])

    model_acc = []
    for nte in xti:
        try:
            model_acc.append(hist[f't{k}_acc_nte_{nte}'][-last_cpt])
        except LookupError:
            # Missing bin or checkpoint: NaN leaves a gap in the line and,
            # unlike None, survives the interval arithmetic in plot_std.
            model_acc.append(float('nan'))

    plt.plot(range(13), model_acc, line_fmt, markersize=ms, c=our_colors['red'], linewidth=lw, alpha=1)
    plt.plot(ntes, pretr_values, line_fmt, c=our_colors['green'],
             linewidth=lw, alpha=1, markersize=ms)
    plt.plot(ntes, basel_values, line_fmt, linewidth=lw,
             c=our_colors['blue'], markersize=ms, alpha=1)
    plt.plot(range(len(staticQK)), staticQK, line_fmt, markersize=ms, c=our_colors['yellow'], linewidth=lw, alpha=1)

    plt.title(f'USPTO-{dataset}')
    plt.xlabel('number of training examples')
    plt.ylabel('top-100 test-accuracy')
    # Labels are matched to artists by position, so this call must stay
    # before plot_std adds its filled bands.
    plt.legend([model_legend, 'Fortunato et al.', 'FNN baseline', "FPM baseline",
                'test sample proportion'])

    if draw_std and n_samples is not None:
        alpha = 0.2
        plot_std(model_acc, n_samples, z=z, color=our_colors['red'], alpha=alpha)
        plot_std(pretr_values, n_samples, z=z, color=our_colors['green'], alpha=alpha)
        plot_std(basel_values, n_samples, z=z, color=our_colors['blue'], alpha=alpha)
        plot_std(staticQK, n_samples, z=z, color=our_colors['yellow'], alpha=alpha)

    plt.xticks(range(13), xti)
    plt.yticks(np.arange(0, 1.05, 0.1))
    plt.grid(True, alpha=0.3)