MHN-React / mhnreact /plotutils.py
uragankatrrin's picture
Upload 12 files
2956799
raw
history blame
7.51 kB
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: [email protected]
Plot utils
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
plt.style.use('default')
def normal_approx_interval(p_hat, n, z=1.96):
    r"""Return the half-width of the normal-approximation interval of a proportion.

    Approximates the distribution of error about a binomially-distributed
    observation, :math:`\hat{p}`, with a normal distribution and returns
    ``z * sqrt(p_hat * (1 - p_hat) / n)``.

    z = 1.96 --> alpha = 0.05
    z = 1    --> std
    https://www.wikiwand.com/en/Binomial_proportion_confidence_interval

    Args:
        p_hat: observed proportion (scalar or NumPy array).
        n: number of observations the proportion was computed from.
        z: multiplier applied to the standard error.

    Returns:
        The half-width to add to / subtract from ``p_hat``.
    """
    # The docstring is a raw string so that ``\hat`` is not parsed as an
    # (invalid) escape sequence.
    variance = (p_hat * (1 - p_hat)) / n
    return z * variance ** (1 / 2)
# Named RGB palette used by the plotting helpers below.  Channels are written
# as 0-255 integers and scaled to the 0-1 floats matplotlib expects.
# Note: "lightblue" and "blue" currently map to the same colour.
our_colors = {
    color_name: tuple(channel / 255 for channel in rgb)
    for color_name, rgb in (
        ("lightblue", (0, 132, 187)),
        ("red", (217, 92, 76)),
        ("blue", (0, 132, 187)),
        ("green", (91, 167, 85)),
        ("yellow", (241, 188, 63)),
        ("cyan", (79, 176, 191)),
        ("grey", (125, 130, 140)),
        ("lightgreen", (191, 206, 82)),
        ("violett", (174, 97, 157)),
    )
}
def plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
    """Shade the band ``p_hats +/- normal_approx_interval(p_hats, n_samples, z)``.

    Args:
        p_hats: sequence of observed proportions; ``None`` entries are treated
            as missing (NaN) instead of raising a ``TypeError``.
        n_samples: sequence with the number of observations behind each entry
            of ``p_hats``; it may be longer than ``p_hats``, extra entries are
            ignored.
        z: multiplier forwarded to ``normal_approx_interval``.
        color: fill colour of the band.
        alpha: opacity of the band.
        xs: x positions of the points; defaults to ``0 .. len(p_hats) - 1``.

    Raises:
        IndexError: if ``n_samples`` has fewer entries than ``p_hats``.
    """
    # dtype=float turns None placeholders into NaN so the arithmetic below is
    # well defined for partially filled curves.
    p_hats = np.array(p_hats, dtype=float)
    counts = np.asarray(n_samples, dtype=float)[:p_hats.size]
    if counts.size != p_hats.size:
        raise IndexError(
            f'n_samples has {counts.size} entries but p_hats has {p_hats.size}'
        )
    half_widths = normal_approx_interval(p_hats, counts, z=z)
    xs = range(len(p_hats)) if xs is None else xs
    plt.fill_between(xs, p_hats - half_widths, p_hats + half_widths, color=color, alpha=alpha)
def plot_loss(hist):
    """Draw the training and validation loss curves stored in ``hist``.

    ``hist`` is indexed with the keys ``'step'``/``'loss'`` for the training
    curve and ``'steps_valid'``/``'loss_valid'`` for the validation curve.
    """
    # Training curve first, so it takes the first legend entry.
    train_steps, train_loss = hist['step'], hist['loss']
    plt.plot(train_steps, train_loss)

    valid_steps = hist['steps_valid']
    valid_loss = np.array(hist['loss_valid'])
    plt.plot(valid_steps, valid_loss)

    curve_names = ['train', 'validation']
    plt.legend(curve_names)
    plt.xlabel('update-step')
    plt.ylabel('loss (categorical-crossentropy-loss)')
def plot_topk(hist, sets=('train', 'valid', 'test'), with_last=2):
    """Plot top-k accuracy against k for each data split in ``sets``.

    A hard-coded reference curve (``baseline_val_res``) is drawn in black, then
    one curve per split is read from ``hist[f't{k}_acc_{split}']``.

    Args:
        hist: mapping from ``'t{k}_acc_{split}'`` to a sequence of accuracies,
            one entry per checkpoint.
        sets: names of the splits to plot.
        with_last: the last ``with_last - 1`` checkpoints are drawn; older
            checkpoints are drawn with lower opacity (``alpha = 1 / i``).
    """
    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
    plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')

    # Remember the line drawn for the most recent checkpoint of every split so
    # the legend labels are attached to those lines.  Passing only the labels
    # to plt.legend would label the black reference curve with sets[0] and
    # shift every other label by one line.
    legend_handles = []
    legend_labels = []
    for i in range(1, with_last):
        for s in sets:
            (line,) = plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks], '.--', alpha=1 / i)
            if i == 1:
                legend_handles.append(line)
                legend_labels.append(s)

    plt.xlabel('top-k')
    plt.ylabel('Accuracy')
    plt.legend(legend_handles, legend_labels)
    plt.title('Hopfield-NN')
    plt.ylim([-0.02, 1])
def _nte_reference_curves(dataset, group_by_template_fp=False):
    """Return the hard-coded ``(basel_values, pretr_values, staticQK)`` lists for ``dataset``."""
    if dataset == 'Sm':
        basel_values = [0., 0.38424785, 0.66807858, 0.7916149, 0.9051132,
                        0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
                        0.97215858, 0.9896713, 0.99917817]
        pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
                        0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
                        0.95907246, 0.9839079, 0.98683919]
        staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142,
                    0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
        if group_by_template_fp:
            staticQK = [0.2651, 0.2617, 0.261, 0.2181, 0.2622, 0.2393, 0.2157,
                        0.2184, 0.2, 0.225, 0.2039, 0.4568, 0.5293]
        return basel_values, pretr_values, staticQK
    if dataset == 'Lg':
        pretr_values = [0.03410448, 0.65397054, 0.7254572, 0.78969294, 0.81329924,
                        0.8651173, 0.86775655, 0.8593128, 0.88184124, 0.87764794,
                        0.89734215, 0.93328846, 0.99531597]
        basel_values = [0., 0.62478044, 0.68784314, 0.75089511, 0.77044644,
                        0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
                        0.8662873, 0.92308414, 1.00042408]
        staticQK = [0.006416, 0.00686, 0.00616, 0.00825, 0.005085, 0.006718, 0.01041,
                    0.0015335, 0.006668, 0.004673, 0.001706, 0.02551, 0.04074]
        return basel_values, pretr_values, staticQK
    if dataset == 'Golden':
        return [0] * 13, [0] * 13, [0] * 13
    raise ValueError(f"unknown dataset {dataset!r}; expected 'Sm', 'Lg' or 'Golden'")


def _default_n_samples(dataset, group_by_template_fp=False):
    """Return the hard-coded per-bin test-sample counts for ``dataset``, or None if there are none."""
    if dataset == 'Sm':
        if group_by_template_fp:
            return [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
        return [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
    if dataset == 'Lg':
        if group_by_template_fp:
            return [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
        return [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958]
    return None


def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
             draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False,
             schwaller_hist=None, fortunato_hist=None):
    """Plot top-100 test accuracy per "number of training examples" bin.

    The x axis has 13 bins: ``'0'`` .. ``'10'``, ``'>10'`` and ``'>49'``.  Four
    curves are drawn: the model read from ``hist`` (red), ``pretr_values``
    (green), ``basel_values`` (blue) and ``staticQK`` (yellow).  The legend
    labels them, in that order, ``model_legend``, 'Fortunato et al.',
    'FNN baseline' and 'FPM baseline'.

    Args:
        hist: mapping from ``'t100_acc_nte_{bin}'`` to a sequence of
            accuracies.  Bins that cannot be read are plotted as gaps.
        dataset: ``'Sm'``, ``'Lg'`` or ``'Golden'``; selects the hard-coded
            reference curves and is used in the title.
        last_cpt: the entry ``[-last_cpt]`` of every history list is plotted.
        include_bar: draw the per-bin sample proportion as grey bars.
        model_legend: legend label of the curve read from ``hist``.
        draw_std: shade an interval around every curve via ``plot_std``.
        z: interval multiplier forwarded to ``plot_std`` (1.96 for a 95% CI).
        n_samples: per-bin sample counts.  Defaults to hard-coded counts for
            'Sm' and 'Lg'.  NOTE(review): with ``include_bar`` and
            ``group_by_template_fp`` both set, the grouped hard-coded counts
            replace a caller-supplied value, as in the original code
            (the original nesting could not be verified; confirm).
        group_by_template_fp: use the "grouped" variants of the hard-coded
            ``staticQK`` curve ('Sm' only) and sample counts.
        schwaller_hist: if truthy, ``basel_values`` is read from this history
            at the index of its minimal ``'loss_valid'``.
        fortunato_hist: if truthy, ``pretr_values`` is read from this history
            at the index of its minimal ``'loss_valid'``.

    Raises:
        ValueError: if ``dataset`` has no hard-coded reference curves.

    If no sample counts are available (e.g. ``dataset='Golden'`` without
    ``n_samples``) the bars and the shaded intervals are skipped.
    """
    line_fmt = '.--'
    lw = 2
    ms = 8
    xti = [*[str(i) for i in range(11)], '>10', '>49']
    ntes = range(13)

    basel_values, pretr_values, staticQK = _nte_reference_curves(dataset, group_by_template_fp)

    # Optional overrides of the reference curves, taken at the checkpoint with
    # the lowest validation loss.
    if schwaller_hist:
        midx = np.argmin(schwaller_hist['loss_valid'])
        basel_values = [schwaller_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]
    if fortunato_hist:
        midx = np.argmin(fortunato_hist['loss_valid'])
        pretr_values = [fortunato_hist[f't100_acc_nte_{nte}'][midx] for nte in xti]

    # Resolve the sample counts once, so the std shading also works when the
    # bars are disabled.
    default_counts = _default_n_samples(dataset, group_by_template_fp)
    if include_bar and group_by_template_fp and default_counts is not None:
        n_samples = default_counts
    elif n_samples is None:
        n_samples = default_counts

    if include_bar and n_samples is not None:
        # The last bin is left out of the normalising sum.
        plt.bar(ntes, np.array(n_samples) / sum(n_samples[:-1]),
                alpha=0.4, color=our_colors['grey'])

    model_values = []
    for nte in xti:
        try:
            value = hist[f't100_acc_nte_{nte}'][-last_cpt]
        except (KeyError, IndexError, TypeError):
            # Missing bin, too short history or no history at all: leave a gap.
            # NaN (not None) keeps the std computation below well defined.
            value = float('nan')
        model_values.append(value)

    plt.plot(ntes, model_values, line_fmt, markersize=ms, c=our_colors['red'],
             linewidth=lw, alpha=1)
    plt.plot(ntes, pretr_values, line_fmt, c=our_colors['green'],
             linewidth=lw, alpha=1, markersize=ms)
    plt.plot(ntes, basel_values, line_fmt, linewidth=lw,
             c=our_colors['blue'], markersize=ms, alpha=1)
    plt.plot(range(len(staticQK)), staticQK, line_fmt, markersize=ms,
             c=our_colors['yellow'], linewidth=lw, alpha=1)

    plt.title(f'USPTO-{dataset}')
    plt.xlabel('number of training examples')
    plt.ylabel('top-100 test-accuracy')
    # The legend is created before the shaded bands are drawn, so the bands do
    # not get legend entries.
    plt.legend([model_legend, 'Fortunato et al.', 'FNN baseline', "FPM baseline",
                'test sample proportion'])

    if draw_std and n_samples is not None:
        band_alpha = 0.2
        for values, color_name in ((model_values, 'red'), (pretr_values, 'green'),
                                   (basel_values, 'blue'), (staticQK, 'yellow')):
            plot_std(values, n_samples, z=z, color=our_colors[color_name], alpha=band_alpha)

    plt.xticks(range(13), xti)
    plt.yticks(np.arange(0, 1.05, 0.1))
    plt.grid(True, alpha=0.3)