# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, List

import torch

import llm_transparency_tool.routes.contributions as contributions


class TestContributions(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.eps = 1e-4
        # It may be useful to run the tests on GPU to catch issues with temporary
        # tensors being created on another device, but keep this off by default.
        self.test_on_gpu = False
        self.device = "cuda" if self.test_on_gpu else "cpu"
        self.batch = 4
        self.tokens = 5
        self.heads = 6
        self.d_model = 10
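        # The decomposed attention tensor below holds per-(source token, head)
        # output terms. The index order is an assumption inferred from the shape
        # checks further down: [batch, destination token, source token, head, d_model].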
        self.decomposed_attn = torch.rand(
            self.batch,
            self.tokens,
            self.tokens,
            self.heads,
            self.d_model,
            device=self.device,
        )
        self.mlp_out = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_pre = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_mid = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_post = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )

    def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
        self.assertTrue(
            torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(),
            t,
        )
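
    # NOTE: the expected values in the tests below are all consistent with a
    # distance-based contribution rule. This is an assumption inferred from the
    # numbers, not necessarily the exact formula used by `contributions`: with
    # the L2 norm (distance_norm=2), each part x_i of a sum y = sum_i x_i
    # (the residual-stream input counts as one more part) gets a raw score
    # c_i = relu(||y|| - ||y - x_i||), and the scores are normalized to sum to 1.
    # A minimal reference sketch of that assumed rule, kept here for clarity only:
    @staticmethod
    def _reference_contributions(parts: torch.Tensor, total: torch.Tensor):
        # parts: [n_parts, d_model], total: [d_model], with total == parts.sum(0).
        dist = torch.linalg.norm(total - parts, dim=-1)
        raw = torch.relu(torch.linalg.norm(total) - dist)
        return raw / raw.sum()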

    def test_mlp_contributions(self):
        mlp_out = torch.tensor([[[1.0, 1.0]]])
        resid_mid = torch.tensor([[[0.0, 0.0]]])
        resid_post = torch.tensor([[[1.0, 1.0]]])
        c_mlp, c_residual = contributions.get_mlp_contributions(
            resid_mid, resid_post, mlp_out
        )
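        # resid_mid is zero, so the MLP output accounts for all of resid_post and
        # should get the full contribution (under the assumed distance rule:
        # ||post - mlp_out|| = 0 while ||post - resid_mid|| = ||post||).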
        self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
        self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)

    def test_decomposed_attn_contributions(self):
        resid_pre = torch.tensor([[[2.0, 1.0]]])
        resid_mid = torch.tensor([[[2.0, 2.0]]])
        decomposed_attn = torch.tensor(
            [
                [
                    [
                        [
                            [1.0, 1.0],
                            [-1.0, 0.0],
                        ]
                    ]
                ]
            ]
        )
        c_attn, c_residual = contributions.get_attention_contributions(
            resid_pre, resid_mid, decomposed_attn, distance_norm=2
        )
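        # With y = resid_mid = [2, 2] (||y|| ~= 2.828), the assumed distance rule
        # gives raw scores of ~1.414 for the [1, 1] head term, 0 for [-1, 0]
        # (it points away from y), and ~1.828 for the residual [2, 1]; after
        # normalization: 1.414 / 3.243 ~= 0.43613 and 1.828 / 3.243 ~= 0.56387.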
        self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
        self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)

    def test_decomposed_mlp_contributions(self):
        pre = torch.tensor([10.0, 10.0])
        post = torch.tensor([-10.0, 10.0])
        neuron_impacts = torch.tensor(
            [
                [0.0, 1.0],
                [1.0, 0.0],
                [-21.0, -1.0],
            ]
        )
        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )
        # A bit counter-intuitive, but the only vector pointing from 0 towards the
        # output is the first one.
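        # (Under the assumed distance rule: only [0.0, 1.0] satisfies
        # ||post - x|| < ||post|| (~13.45 vs ~14.14); the other two impacts and
        # the residual `pre` all overshoot and are clipped to zero, so the first
        # neuron receives the whole normalized contribution.)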
        self._assert_tensor_eq(c_mlp, [1, 0, 0])
        self.assertAlmostEqual(c_residual, 0, delta=self.eps)

    def test_decomposed_mlp_contributions_single_direction(self):
        pre = torch.tensor([1.0, 1.0])
        post = torch.tensor([4.0, 4.0])
        neuron_impacts = torch.tensor(
            [
                [1.0, 1.0],
                [2.0, 2.0],
            ]
        )
        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )
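        # All parts are parallel to `post`, so (under the assumed rule) the
        # contributions simply split in proportion to the vectors' lengths,
        # 1 : 2 : 1 for [neuron 0, neuron 1, residual] -> 0.25, 0.5, 0.25.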
        self._assert_tensor_eq(c_mlp, [0.25, 0.5])
        self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)

    def test_attention_contributions_shape(self):
        c_attn, c_residual = contributions.get_attention_contributions(
            self.resid_pre, self.resid_mid, self.decomposed_attn
        )
        self.assertEqual(
            list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
        )
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_mlp_contributions_shape(self):
        c_mlp, c_residual = contributions.get_mlp_contributions(
            self.resid_mid, self.resid_post, self.mlp_out
        )
        self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_renormalizing_threshold(self):
        c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]])
        c_residual = torch.Tensor([0.8, 0.9])
        norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
            0.1, c_blocks, c_residual
        )
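        # Presumably entries below the 0.1 threshold are zeroed and each row is
        # renormalized so that blocks + residual sum to 1 again:
        # row 0: block 0.15 / 0.95 with residual 0.8 / 0.95; row 1: residual only.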
        self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
        self._assert_tensor_eq(norm_residual, [0.842105, 1.0])