File size: 9,467 Bytes
ce00289 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional
import torch
import transformer_lens
import transformers
from fancy_einsum import einsum
from jaxtyping import Float, Int
from typeguard import typechecked
import streamlit as st
from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm
@dataclass
class _RunInfo:
    """Everything captured from the most recent call to `run`.

    The fields keep their original declaration order, so the generated
    positional ``__init__`` signature is unchanged.
    """

    # Token ids that were fed into the model.
    tokens: Int[torch.Tensor, "batch pos"]
    # Model output logits for every position of every sequence in the batch.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Activation cache produced together with the logits by `run_with_cache`.
    cache: transformer_lens.ActivationCache
@st.cache_resource(
    max_entries=1,
    show_spinner=True,
    # NOTE(review): confirm whether Streamlit applies these hash_funcs to
    # subclasses (concrete HF model classes); if it matches exact types only,
    # `hf_model` is hashed with Streamlit's default hashing instead of `id`.
    hash_funcs={
        transformers.PreTrainedModel: id,
        transformers.PreTrainedTokenizer: id,
    },
)
def load_hooked_transformer(
    model_name: str,
    hf_model: Optional[transformers.PreTrainedModel] = None,
    tlens_device: str = "cuda",
    dtype: torch.dtype = torch.float32,
):
    """Create a `transformer_lens.HookedTransformer` and switch it to eval mode.

    The result is cached with `st.cache_resource(max_entries=1)`, so at most one
    loaded model is kept in the cache at a time.

    Args:
        model_name: Official model name passed to `HookedTransformer.from_pretrained`.
        hf_model: Optional already-instantiated HuggingFace model to wrap.
        tlens_device: Device string forwarded to `from_pretrained`.
        dtype: Parameter dtype forwarded to `from_pretrained`.

    Returns:
        The hooked transformer, after `.eval()` has been called on it.
    """
    # Keep the weights as they are in the checkpoint: layer norm is not folded
    # and neither the writing weights nor the unembedding are centered.
    pretrained_options = dict(
        hf_model=hf_model,
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        device=tlens_device,
        dtype=dtype,
    )
    hooked = transformer_lens.HookedTransformer.from_pretrained(
        model_name, **pretrained_options
    )
    hooked.eval()
    return hooked
# TODO(igortufanov): If we want to scale the app to multiple users, we need more careful
# thread-safe implementation. The simplest option could be to wrap the existing methods
# in mutexes.
class TransformerLensTransparentLlm(TransparentLlm):
    """
    Implementation of Transparent LLM based on transformer lens.

    Args:
    - model_name: The official name of the model from HuggingFace. Even if the model was
        patched or loaded locally, the name should still be official because that's how
        transformer_lens treats the model.
    - hf_model: The language model as a HuggingFace class.
    - tokenizer: Optional HuggingFace tokenizer. If given, it is installed on the
        hooked model with left padding.
    - device: "gpu" or "cpu". "gpu" selects the "cuda" torch device.
    - dtype: The dtype used when loading the hooked model.

    Raises:
    - RuntimeError: if `device` is neither "gpu" nor "cpu", or if "gpu" is requested
        while torch reports that CUDA is unavailable.

    Methods that read model outputs raise RuntimeError until `run` has been called.
    """

    def __init__(
        self,
        model_name: str,
        hf_model: Optional[transformers.PreTrainedModel] = None,
        tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
        device: str = "gpu",
        dtype: torch.dtype = torch.float32,
    ):
        if device == "gpu":
            self.device = "cuda"
            if not torch.cuda.is_available():
                # Fail fast: the "cuda" device cannot be used without CUDA.
                raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
        elif device == "cpu":
            self.device = "cpu"
        else:
            raise RuntimeError(f"Specified device {device} is not a valid option")

        self.dtype = dtype
        self.hf_tokenizer = tokenizer
        self.hf_model = hf_model

        self._model_name = model_name
        self._prepend_bos = True
        self._last_run = None
        self._run_exception = RuntimeError(
            "Tried to use the model output before calling the `run` method"
        )

    def copy(self):
        """Return a shallow copy (attributes such as the last run are shared)."""
        import copy

        return copy.copy(self)

    @property
    def _model(self):
        """The hooked transformer, configured for this instance.

        The model object comes from `load_hooked_transformer` (a Streamlit-cached
        resource); the tokenizer and hook settings below are re-applied on every
        access.
        """
        tlens_model = load_hooked_transformer(
            self._model_name,
            hf_model=self.hf_model,
            tlens_device=self.device,
            dtype=self.dtype,
        )
        if self.hf_tokenizer is not None:
            tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
        # Per-head attention outputs are read from "attn.hook_result" below.
        tlens_model.set_use_attn_result(True)
        tlens_model.set_use_attn_in(False)
        tlens_model.set_use_split_qkv_input(False)
        return tlens_model

    def _require_last_run(self) -> _RunInfo:
        """Return the results of the last `run` call.

        Raises:
            RuntimeError: if `run` has not been called yet.
        """
        if self._last_run is None:
            # Raise a fresh exception each time: re-raising the single stored
            # instance would keep extending its __traceback__ across calls.
            raise RuntimeError(*self._run_exception.args)
        return self._last_run

    def model_info(self) -> ModelInfo:
        """Summarize the loaded model's configuration."""
        cfg = self._model.cfg
        return ModelInfo(
            name=self._model_name,
            n_params_estimate=cfg.n_params,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_model=cfg.d_model,
            d_vocab=cfg.d_vocab,
        )

    @torch.no_grad()
    def run(self, sentences: List[str]) -> None:
        """Tokenize `sentences`, run the model with caching and store the results."""
        model = self._model
        tokens = model.to_tokens(sentences, prepend_bos=self._prepend_bos)
        logits, cache = model.run_with_cache(tokens)
        self._last_run = _RunInfo(
            tokens=tokens,
            logits=logits,
            cache=cache,
        )

    def batch_size(self) -> int:
        """Number of sequences processed by the last `run`."""
        return self._require_last_run().logits.shape[0]

    @typechecked
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        """Token ids fed to the model in the last `run`."""
        return self._require_last_run().tokens

    @typechecked
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        """Convert a 1-D tensor of token ids into their string tokens."""
        return self._model.to_str_tokens(tokens)

    @typechecked
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        """Logits produced by the last `run`."""
        return self._require_last_run().logits

    @torch.no_grad()
    @typechecked
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """Project a single residual-stream vector onto the vocabulary.

        If `normalize` is true, the model's final layer norm is applied first.
        """
        model = self._model
        # t: [d_model] -> [batch, pos, d_model]
        tdim = t.unsqueeze(0).unsqueeze(0)
        if normalize:
            tdim = model.ln_final(tdim)
        return model.unembed(tdim)[0][0]

    def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
        """Return the cached activation `blocks.{layer}.{block_name}` of the last run."""
        return self._require_last_run().cache[f"blocks.{layer}.{block_name}"]

    # ================= Methods related to the residual stream =================

    @typechecked
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """Residual stream entering `layer` (hook_resid_pre)."""
        return self._get_block(layer, "hook_resid_pre")

    @typechecked
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Residual stream after the attention sub-layer of `layer` (hook_resid_mid)."""
        return self._get_block(layer, "hook_resid_mid")

    @typechecked
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """Residual stream leaving `layer` (hook_resid_post)."""
        return self._get_block(layer, "hook_resid_post")

    # ================ Methods related to the feed-forward layer ===============

    @typechecked
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """Output of the MLP sub-layer of `layer` (hook_mlp_out)."""
        return self._get_block(layer, "hook_mlp_out")

    @torch.no_grad()
    @typechecked
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """Per-neuron contributions to the MLP output at one position.

        Row i is the activation of neuron i times row i of `W_out[layer]`.
        """
        # Take activations right before they're multiplied by W_out, i.e. non-linearity
        # and layer norm are already applied.
        processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
        return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])

    @typechecked
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden"]:
        """MLP neuron values at one position, read from the `mlp.hook_pre` hook."""
        return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]

    @typechecked
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """Row `neuron` of `W_out[layer]`, the vector the neuron's activation scales."""
        return self._model.W_out[layer][neuron]

    # ==================== Methods related to the attention ====================

    @typechecked
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """Attention pattern of one head for one sequence of the batch."""
        return self._get_block(layer, "attn.hook_pattern")[batch_i][head]

    @typechecked
    def attention_output_per_head(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """Output of a single attention head at one position (attn.hook_result)."""
        return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]

    @typechecked
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """Total attention sub-layer output at one position (hook_attn_out)."""
        return self._get_block(layer, "hook_attn_out")[batch_i][pos]

    @torch.no_grad()
    @typechecked
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
        """Split the attention output of `layer` by key position and head.

        Entry [q, k, h] is the vector head h writes at query position q because of
        key position k (value vector weighted by attention, then projected by W_O).
        The output bias b_O is not included.
        """
        # "attn.hook_v" already contains the value bias b_V: TransformerLens adds
        # it before the hook fires, so it must not be added a second time here.
        v = self._get_block(layer, "attn.hook_v")[batch_i]
        pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
        z = einsum(
            "key_pos head d_head, "
            "head query_pos key_pos -> "
            "query_pos key_pos head d_head",
            v,
            pattern,
        )
        return einsum(
            "pos key_pos head d_head, "
            "head d_head d_model -> "
            "pos key_pos head d_model",
            z,
            self._model.W_O[layer],
        )
|