import numpy as np
from scipy.stats import pearsonr, spearmanr  # noqa: F401  (pearsonr kept for existing importers)

try:
    from tqdm.auto import tqdm
except ImportError:
    # The progress bar is purely cosmetic, so a missing tqdm must not make
    # the correlation helpers unusable.
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is unavailable: iterate without a progress bar."""
        return iterable


def _with_layer_axis(attentions_list):
    """Return ``attentions_list`` with every item shaped (layers, seq_len, seq_len).

    The decision is made from the first item only: if it is 2-D, every item
    is reshaped to (1, rows, cols); otherwise the input is returned unchanged.
    Callers must pass a non-empty sequence.
    """
    if len(attentions_list[0].shape) == 2:  # No layers
        return [a.reshape(1, a.shape[0], a.shape[1]) for a in attentions_list]
    return attentions_list


def compute_spearman_correlation(
    attentions_list,
    saliency_file_address,
    desc="",
    aggregation="CLS",
    max_length=512,
):
    """
    Spearman correlation between aggregated attention and saliency scores.

    :param attentions_list: sequence of per-sample attention arrays, each
        (#layers, sentence_len, sentence_len) or, without a layer axis,
        (sentence_len, sentence_len)
    :param saliency_file_address: path of a ``.npy`` file loaded with
        ``np.load`` and indexed as ``saliencies[sample][token]``
    :param desc: tqdm desc
    :param aggregation: CLS (row 0 of the attention matrix, i.e. based on what
        affects CLS) | SUM (attention summed over axis 0, i.e. based on the
        effect on all tokens)
    :param max_length: only the first ``max_length`` tokens are correlated
    :return: list with one array per sample, each of shape (#layers,)
    :raises ValueError: if ``aggregation`` is neither "CLS" nor "SUM"
    """
    # Validate up front so a bad argument fails before any file I/O and even
    # when there is nothing to iterate over.
    if aggregation not in ("CLS", "SUM"):
        raise ValueError("Undefined aggregation method. Possible values: CLS, SUM")
    if len(attentions_list) == 0:
        return []

    saliencies = np.load(saliency_file_address)
    attentions_list = _with_layer_axis(attentions_list)

    spearmans = []
    for i, attention in enumerate(tqdm(attentions_list, desc=desc)):
        # Same for every layer of this sample, so compute it once.
        length = min(len(attention[0]), max_length)
        saliency = saliencies[i][:length]

        layer_scores = []
        for layer_attention in attention:
            if aggregation == "CLS":
                aggregated = layer_attention[0]
            else:  # "SUM"
                aggregated = layer_attention.sum(axis=0)
            layer_scores.append(spearmanr(aggregated[:length], saliency).correlation)
        spearmans.append(np.array(layer_scores))
    return spearmans


def compute_spearman_correlation_hta(attentions_list, hta_file_address, desc="", max_length=512):
    """
    Per-attender Spearman correlation between attention rows and HTA rows.

    :param attentions_list: sequence of per-sample attention arrays, each
        (#layers, seq_len, seq_len) or, without a layer axis,
        (seq_len, seq_len)
    :param hta_file_address: path of a ``.npy`` file loaded with ``np.load``
        and indexed as ``hta[layer][sample][attender][token]``
    :param desc: tqdm desc
    :param max_length: only the first ``max_length`` attenders/tokens are used
    :return: list with one array per sample, each of shape
        (#layers, min(seq_len, max_length)) = (layers, attender)
    """
    if len(attentions_list) == 0:
        return []

    hta = np.load(hta_file_address)
    attentions_list = _with_layer_axis(attentions_list)

    spearmans = []
    for i, attention in enumerate(tqdm(attentions_list, desc=desc)):
        length = min(len(attention[0]), max_length)

        sample_scores = []
        for layer, layer_attention in enumerate(attention):
            layer_hta = hta[layer][i]
            attender_scores = [
                spearmanr(layer_attention[attender][:length], layer_hta[attender][:length]).correlation
                for attender in range(length)
            ]
            sample_scores.append(np.array(attender_scores))
        spearmans.append(np.array(sample_scores))
    return spearmans