|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
from typing import Any, List |
|
|
|
import torch |
|
|
|
import llm_transparency_tool.routes.contributions as contributions |
|
|
|
|
|
class TestContributions(unittest.TestCase):
    """Unit tests for ``llm_transparency_tool.routes.contributions``.

    Two kinds of checks are performed:

    * numeric checks on tiny hand-written tensors, compared against
      reference values with an absolute tolerance of ``self.eps``;
    * shape checks on the random fixtures created in ``setUp``.
    """

    def setUp(self):
        """Seed the RNG and create the random fixtures used by the shape tests."""
        # Fixed seed so the random fixtures are reproducible between runs.
        torch.manual_seed(123)

        # Absolute tolerance for every numeric comparison in this class.
        self.eps = 1e-4

        # Flip to True to place the random fixtures on CUDA. Kept off by
        # default so the suite runs on machines without a GPU.
        self.test_on_gpu = False
        self.device = "cuda" if self.test_on_gpu else "cpu"

        self.batch = 4
        self.tokens = 5
        self.heads = 6
        self.d_model = 10

        # NOTE: the order of the torch.rand calls determines the values each
        # fixture receives for the fixed seed above; keep it stable.
        self.decomposed_attn = torch.rand(
            self.batch,
            self.tokens,
            self.tokens,
            self.heads,
            self.d_model,
            device=self.device,
        )
        self.mlp_out = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_pre = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_mid = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_post = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )

    def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
        """Assert that ``t`` equals ``expected`` element-wise within ``self.eps``.

        Args:
            t: Tensor produced by the code under test.
            expected: (Nested) list of reference values with the same shape
                as ``t``.
        """
        # Build the reference with the dtype/device of ``t`` so the comparison
        # also works when the tensor under test is not a CPU float32 tensor.
        expected_t = torch.tensor(expected, dtype=t.dtype, device=t.device)

        # torch.isclose broadcasts its arguments, so a result with the wrong
        # shape could otherwise pass silently.
        self.assertEqual(
            tuple(t.shape),
            tuple(expected_t.shape),
            f"shape mismatch: expected {expected}, got {t}",
        )
        self.assertTrue(
            torch.isclose(t, expected_t, atol=self.eps).all(),
            f"expected {expected}, got {t}",
        )

    def test_mlp_contributions(self):
        """With a zero residual input, the whole output is attributed to the MLP."""
        mlp_out = torch.tensor([[[1.0, 1.0]]])
        resid_mid = torch.tensor([[[0.0, 0.0]]])
        resid_post = torch.tensor([[[1.0, 1.0]]])

        c_mlp, c_residual = contributions.get_mlp_contributions(
            resid_mid, resid_post, mlp_out
        )

        self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
        self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)

    def test_decomposed_attn_contributions(self):
        """Check per-head attention contributions against reference values."""
        # resid_pre plus both head outputs adds up to resid_mid:
        # [2, 1] + [1, 1] + [-1, 0] == [2, 2].
        resid_pre = torch.tensor([[[2.0, 1.0]]])
        resid_mid = torch.tensor([[[2.0, 2.0]]])
        decomposed_attn = torch.tensor(
            [
                [
                    [
                        [
                            [1.0, 1.0],
                            [-1.0, 0.0],
                        ]
                    ]
                ]
            ]
        )

        c_attn, c_residual = contributions.get_attention_contributions(
            resid_pre, resid_mid, decomposed_attn, distance_norm=2
        )

        # Reference values for distance_norm=2.
        self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
        self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)

    def test_decomposed_mlp_contributions(self):
        """Check per-neuron MLP contributions when neurons point in different directions."""
        # pre plus all neuron impacts adds up to post:
        # [10, 10] + [0, 1] + [1, 0] + [-21, -1] == [-10, 10].
        pre = torch.tensor([10.0, 10.0])
        post = torch.tensor([-10.0, 10.0])
        neuron_impacts = torch.tensor(
            [
                [0.0, 1.0],
                [1.0, 0.0],
                [-21.0, -1.0],
            ]
        )

        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )

        # Expected: the first neuron receives the full contribution, the other
        # neurons and the residual receive none.
        self._assert_tensor_eq(c_mlp, [1, 0, 0])
        self.assertAlmostEqual(c_residual, 0, delta=self.eps)

    def test_decomposed_mlp_contributions_single_direction(self):
        """Check per-neuron MLP contributions when everything is collinear."""
        # pre and both neuron impacts are multiples of [1, 1] and add up to
        # post ([4, 4]); the expected contributions are the shares 1/4, 2/4
        # for the neurons and 1/4 for the residual.
        pre = torch.tensor([1.0, 1.0])
        post = torch.tensor([4.0, 4.0])
        neuron_impacts = torch.tensor(
            [
                [1.0, 1.0],
                [2.0, 2.0],
            ]
        )

        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )

        self._assert_tensor_eq(c_mlp, [0.25, 0.5])
        self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)

    def test_attention_contributions_shape(self):
        """Attention contributions drop the model dimension of the inputs."""
        c_attn, c_residual = contributions.get_attention_contributions(
            self.resid_pre, self.resid_mid, self.decomposed_attn
        )

        self.assertEqual(
            list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
        )
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_mlp_contributions_shape(self):
        """MLP and residual contributions both have shape (batch, tokens)."""
        c_mlp, c_residual = contributions.get_mlp_contributions(
            self.resid_mid, self.resid_post, self.mlp_out
        )

        self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_renormalizing_threshold(self):
        """Values below the threshold are zeroed and the rest renormalized."""
        c_blocks = torch.tensor([[0.05, 0.15], [0.05, 0.05]])
        c_residual = torch.tensor([0.8, 0.9])

        norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
            0.1, c_blocks, c_residual
        )

        # Row 0: only 0.15 survives the 0.1 threshold, so the expected values
        # are 0.15 / 0.95 and 0.8 / 0.95. Row 1: no block survives, so the
        # residual is expected to take the full weight (0.9 / 0.9).
        self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
        self._assert_tensor_eq(norm_residual, [0.842105, 1.0])
|
|