"""Util functions for codebook features.""" | |
import pathlib | |
import re | |
import typing | |
from dataclasses import dataclass | |
from functools import partial | |
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from termcolor import colored | |
from tqdm import tqdm | |


@dataclass
class CodeInfo:
    """Dataclass for codebook info."""

    code: int
    layer: int
    head: Optional[int]
    cb_at: Optional[str] = None

    # for patching interventions
    pos: Optional[int] = None
    code_pos: Optional[int] = -1

    # for description & regex-based interpretation
    description: Optional[str] = None
    regex: Optional[str] = None
    prec: Optional[float] = None
    recall: Optional[float] = None
    num_acts: Optional[int] = None

    def __post_init__(self):
        """Convert to appropriate types."""
        self.code = int(self.code)
        self.layer = int(self.layer)
        if self.head:
            self.head = int(self.head)
        if self.pos:
            self.pos = int(self.pos)
        if self.code_pos:
            self.code_pos = int(self.code_pos)
        if self.prec:
            self.prec = float(self.prec)
            assert 0 <= self.prec <= 1
        if self.recall:
            self.recall = float(self.recall)
            assert 0 <= self.recall <= 1
        if self.num_acts:
            self.num_acts = int(self.num_acts)

    def check_description_info(self):
        """Check if the regex info is present."""
        assert self.num_acts is not None and self.description is not None
        if self.regex is not None:
            assert self.prec is not None and self.recall is not None

    def __repr__(self):
        """Return the string representation."""
        repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
        if self.pos is not None or self.code_pos is not None:
            repr += f", pos={self.pos}, code_pos={self.code_pos}"
        if self.description is not None:
            repr += f", description={self.description}"
        if self.regex is not None:
            repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}"
        if self.num_acts is not None:
            repr += f", num_acts={self.num_acts}"
        repr += ")"
        return repr

    @classmethod
    def from_str(cls, code_txt, *args, **kwargs):
        """Extract code info fields from string."""
        code_txt = code_txt.strip().lower()
        code_txt = code_txt.split(", ")
        code_txt = dict(txt.split(": ") for txt in code_txt)
        return cls(*args, **code_txt, **kwargs)
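
# Illustrative sketch (not from the original source): `from_str` expects
# comma-separated "field: value" pairs matching the dataclass fields, e.g.
#   CodeInfo.from_str("code: 123, layer: 4, head: 2", cb_at="attn")
# which is roughly equivalent to CodeInfo(code=123, layer=4, head=2, cb_at="attn").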


@dataclass
class ModelInfoForWebapp:
    """Model info for webapp."""

    model_name: str
    pretrained_path: str
    dataset_name: str
    num_codes: int
    cb_at: str
    gcb: str
    n_layers: int
    n_heads: Optional[int] = None
    seed: int = 42
    max_samples: int = 2000

    def __post_init__(self):
        """Convert to correct types."""
        self.num_codes = int(self.num_codes)
        self.n_layers = int(self.n_layers)
        if self.n_heads == "None":
            self.n_heads = None
        elif self.n_heads is not None:
            self.n_heads = int(self.n_heads)
        self.seed = int(self.seed)
        self.max_samples = int(self.max_samples)

    @classmethod
    def load(cls, path):
        """Parse model info from path."""
        path = pathlib.Path(path)
        with open(path / "info.txt", "r") as f:
            lines = f.readlines()
        lines = dict(line.strip().split(": ") for line in lines)
        return cls(**lines)

    def save(self, path):
        """Save model info to path."""
        path = pathlib.Path(path)
        with open(path / "info.txt", "w") as f:
            for k, v in self.__dict__.items():
                f.write(f"{k}: {v}\n")


def logits_to_pred(logits, tokenizer, k=5):
    """Convert logits to top-k predictions."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs = sorted_logits.softmax(dim=-1)
    topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]]
    topk_preds = [
        tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch
    ]
    return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
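
# Usage sketch (assumes a HuggingFace-style tokenizer and logits of shape
# [batch, seq_len, vocab]); returns (token_string, probability) pairs for the
# top-k predictions at the final position:
#   logits = cb_model(input_ids)  # hypothetical forward pass returning logits
#   preds = logits_to_pred(logits, tokenizer, k=5)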


def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
    """Return the set of token ids each codebook feature activates on."""
    codebook_ids = cb_acts[cb_key]
    if code is None:
        features_tokens = [[] for _ in range(num_codes)]
        for i in tqdm(range(codebook_ids.shape[0])):
            for j in range(codebook_ids.shape[1]):
                for k in range(codebook_ids.shape[2]):
                    features_tokens[codebook_ids[i, j, k]].append((i, j))
    else:
        idx0, idx1, _ = np.where(codebook_ids == code)
        features_tokens = list(zip(idx0, idx1))
    return features_tokens
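
# Usage sketch (assumes `cb_acts[cb_key]` is an integer array of codebook ids with
# shape [num_examples, seq_len, codes_per_position]; the key name is illustrative):
#   ft_tkns = features_to_tokens("attn_0", cb_acts, num_codes=1024)
#   ft_tkns[code]  # -> list of (example_idx, token_idx) positions where `code` fires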


def color_str(s: str, html: bool, color: Optional[str] = None):
    """Color the string for html or terminal."""
    if html:
        color = "DeepSkyBlue" if color is None else color
        return f"<span style='color:{color}'>{s}</span>"
    else:
        color = "light_cyan" if color is None else color
        return colored(s, color)


def color_tokens_tokfsm(tokens, color_idx, html=False):
    """Separate states with a dash and color the tokens in color_idx."""
    ret_string = ""
    itr_over_color_idx = 0
    tokens_enumerate = enumerate(tokens)
    if tokens[0] == "<|endoftext|>":
        next(tokens_enumerate)
        if color_idx[0] == 0:
            itr_over_color_idx += 1
    for i, c in tokens_enumerate:
        if i % 2 == 1:
            ret_string += "-"
        if itr_over_color_idx < len(color_idx) and i == color_idx[itr_over_color_idx]:
            ret_string += color_str(c, html)
            itr_over_color_idx += 1
        else:
            ret_string += c
    return ret_string


def color_tokens(tokens, color_idx, n=3, html=False):
    """Color the tokens in color_idx."""
    ret_string = ""
    last_colored_token_idx = -1
    for i in color_idx:
        c_str = tokens[i]
        if i <= last_colored_token_idx + 2 * n + 1:
            ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
        else:
            ret_string += "".join(
                tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
            )
            ret_string += " ... "
            ret_string += "".join(tokens[i - n : i])
        ret_string += color_str(c_str, html)
        last_colored_token_idx = i
    ret_string += "".join(
        tokens[
            last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
        ]
    )
    return ret_string


def prepare_example_print(
    example_id,
    example_tokens,
    tokens_to_color,
    html,
    color_fn=color_tokens,
):
    """Format example to print."""
    example_output = color_str(example_id, html, "green")
    example_output += (
        ": "
        + color_fn(example_tokens, tokens_to_color, html=html)
        + ("<br>" if html else "\n")
    )
    return example_output


def print_token_activations_of_code(
    code_act_by_pos,
    tokens,
    is_fsm=False,
    n=3,
    max_examples=100,
    randomize=False,
    html=False,
    return_example_list=False,
):
    """Print the context with the tokens that a code activates on.

    Args:
        code_act_by_pos: list of (example_id, token_pos_id) tuples specifying
            the token positions that a code activates on in a dataset.
        tokens: list of tokens of a dataset.
        is_fsm: whether the dataset is the TokFSM dataset.
        n: context to print around each side of a token that the code activates on.
        max_examples: maximum number of examples to print.
        randomize: whether to randomize the order of examples.
        html: format the printing style for html or terminal.
        return_example_list: whether to return the printed string by examples or as a single string.

    Returns:
        string of all examples formatted if `return_example_list` is False, otherwise
        list of (example_string, num_tokens_colored) tuples for each example.
    """
    if randomize:
        raise NotImplementedError("Randomize not yet implemented.")
    indices = range(len(code_act_by_pos))
    print_output = [] if return_example_list else ""
    curr_ex = code_act_by_pos[0][0]
    total_examples = 0
    tokens_to_color = []
    color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n)
    for idx in indices:
        if total_examples > max_examples:
            break
        i, j = code_act_by_pos[idx]
        if i != curr_ex and curr_ex >= 0:
            # got a new example, so print the previous one
            curr_ex_output = prepare_example_print(
                curr_ex,
                tokens[curr_ex],
                tokens_to_color,
                html,
                color_fn,
            )
            total_examples += 1
            if return_example_list:
                print_output.append((curr_ex_output, len(tokens_to_color)))
            else:
                print_output += curr_ex_output
            curr_ex = i
            tokens_to_color = []
        tokens_to_color.append(j)
    curr_ex_output = prepare_example_print(
        curr_ex,
        tokens[curr_ex],
        tokens_to_color,
        html,
        color_fn,
    )
    if return_example_list:
        print_output.append((curr_ex_output, len(tokens_to_color)))
    else:
        print_output += curr_ex_output
        print_output += color_str("*" * 50, html, "green")
    total_examples += 1
    return print_output


def print_token_activations_of_codes(
    ft_tkns,
    tokens,
    is_fsm=False,
    n=3,
    start=0,
    stop=1000,
    indices=None,
    max_examples=100,
    freq_filter=None,
    randomize=False,
    html=False,
    return_example_list=False,
):
    """Print the tokens for the codebook features."""
    indices = list(range(start, stop)) if indices is None else indices
    num_tokens = len(tokens) * len(tokens[0])
    codes, token_act_freqs, token_acts = [], [], []
    for i in indices:
        tkns_of_code = ft_tkns[i]
        freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens)
        if freq_filter is not None and freq[1] > freq_filter:
            continue
        codes.append(i)
        token_act_freqs.append(freq)
        if len(tkns_of_code) > 0:
            tkn_acts = print_token_activations_of_code(
                tkns_of_code,
                tokens,
                is_fsm,
                n=n,
                max_examples=max_examples,
                randomize=randomize,
                html=html,
                return_example_list=return_example_list,
            )
            token_acts.append(tkn_acts)
        else:
            token_acts.append("")
    return codes, token_act_freqs, token_acts
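
# Pipeline sketch (names are illustrative): token activations are typically computed
# once with `features_to_tokens` and then rendered for a range of codes:
#   ft_tkns = features_to_tokens(cb_key, cb_acts, num_codes)
#   codes, freqs, acts = print_token_activations_of_codes(
#       ft_tkns, tokens, start=0, stop=50, max_examples=10, html=True
#   )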


def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
    """Patch in the `code` at `run_cb_ids`."""
    pos = slice(None) if pos is None else pos
    code_pos = slice(None) if code_pos is None else code_pos

    if code_pos == "append":
        assert pos == slice(None)
        # the appended codebook slot is filled with `code` at every position by the pad
        return F.pad(run_cb_ids, (0, 1), mode="constant", value=code)
    if isinstance(pos, typing.Iterable):
        for p in pos:
            run_cb_ids[:, p, code_pos] = code
    else:
        run_cb_ids[:, pos, code_pos] = code
    return run_cb_ids


def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None):
    """Get the layer name used to store hooks/cache."""
    comp_name = "attn" if "attn" in cb_at else "mlp"
    if gcb_idx is None:
        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids"
    else:
        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids"


def run_model_fn_with_codes(
    input,
    cb_model,
    fn_name,
    fn_kwargs=None,
    list_of_code_infos=(),
):
    """Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`.

    A common use case is running the `run_with_cache` method while activating the codes.
    For running the `generate` method, use `generate_with_codes` instead.
    """
    if fn_kwargs is None:
        fn_kwargs = {}
    hook_fns = [
        partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos)
        for tupl in list_of_code_infos
    ]
    fwd_hooks = [
        (get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
        for i, tupl in enumerate(list_of_code_infos)
    ]
    cb_model.reset_hook_kwargs()
    with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
        ret = getattr(hooked_model, fn_name)(input, **fn_kwargs)
    return ret


def generate_with_codes(
    input,
    cb_model,
    list_of_code_infos=(),
    tokfsm=None,
    generate_kwargs=None,
):
    """Sample from the language model while activating the codes in `list_of_code_infos`."""
    gen = run_model_fn_with_codes(
        input,
        cb_model,
        "generate",
        generate_kwargs,
        list_of_code_infos,
    )
    return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen
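
# Usage sketch (assumes a codebook-wrapped model with a `generate` method; the code,
# layer, and kwargs below are illustrative, not from the original source):
#   steer = CodeInfo(code=42, layer=3, head=None, cb_at="mlp")
#   out = generate_with_codes(input_ids, cb_model, [steer],
#                             generate_kwargs={"max_new_tokens": 20})
# `run_model_fn_with_codes` can likewise wrap other methods, e.g. fn_name="run_with_cache".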


def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
    """Compute the Jensen-Shannon divergence between two distributions."""
    if len(logits1.shape) == 3:
        logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]

    probs1 = F.softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)
    total_m = (0.5 * (probs1 + probs2)).log()

    loss = 0.0
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits1, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits2, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    return 0.5 * loss
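
# Note: this computes JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M) with M = (P + Q) / 2,
# where F.kl_div(input, target, log_target=True) evaluates KL(target || input) on
# log-probabilities. A typical (illustrative) use is comparing next-token distributions
# before and after a code intervention:
#   jsd = JSD(original_logits, patched_logits, pos=-1)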


def cb_hook_key_to_info(layer_hook_key: str):
    """Get the layer info from the codebook layer hook key.

    Args:
        layer_hook_key: the hook key of the codebook layer.
            E.g. `blocks.3.attn.codebook_layer.hook_codebook_ids`

    Returns:
        comp_name: the name of the component the codebook is applied at.
        layer_idx: the layer index.
        gcb_idx: the codebook index if the codebook layer is grouped, otherwise None.
    """
    layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key)
    assert layer_search is not None
    layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2)
    gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key)
    if gcb_idx_search is not None:
        gcb_idx = int(gcb_idx_search.group(1))
    else:
        gcb_idx = None
    return comp_name, layer_idx, gcb_idx


def find_code_changes(cache1, cache2, pos=None):
    """Find the codebook codes that are different between the two caches."""
    for k in cache1.keys():
        if "codebook" in k:
            c1 = cache1[k][0, pos]
            c2 = cache2[k][0, pos]
            if not torch.all(c1 == c2):
                print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())


def common_codes_in_cache(cache_codes, threshold=0.0):
    """Get the common codes in the cache."""
    codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
    counts = counts.float() * 100
    counts /= cache_codes.shape[1]
    counts, indices = torch.sort(counts, descending=True)
    codes = codes[indices]
    indices = counts > threshold
    codes, counts = codes[indices], counts[indices]
    return codes, counts
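
# Usage sketch (assumes `cache` comes from a run_with_cache-style call on the codebook
# model; the hook key is illustrative). Counts are returned as percentages relative to
# dimension 1 of the tensor (typically the sequence length), sorted descending:
#   cache_codes = cache["blocks.3.attn.codebook_layer.hook_codebook_ids"]
#   codes, freqs = common_codes_in_cache(cache_codes, threshold=1.0)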


def parse_topic_codes_string(
    info_str: str,
    pos: Optional[int] = None,
    code_append: Optional[bool] = False,
    **code_info_kwargs,
):
    """Parse the topic codes string."""
    code_info_strs = info_str.strip().split("\n")
    code_info_strs = [e.strip() for e in code_info_strs if e]
    topic_codes = []
    layer, head = None, None
    if code_append is None:
        code_pos = None
    else:
        code_pos = "append" if code_append else -1
    for code_info_str in code_info_strs:
        topic_codes.append(
            CodeInfo.from_str(
                code_info_str,
                pos=pos,
                code_pos=code_pos,
                **code_info_kwargs,
            )
        )
        if code_append is None or code_append:
            continue
        if layer == topic_codes[-1].layer and head == topic_codes[-1].head:
            code_pos -= 1  # type: ignore
        else:
            code_pos = -1
        topic_codes[-1].code_pos = code_pos
        layer, head = topic_codes[-1].layer, topic_codes[-1].head
    return topic_codes
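
# Illustrative input format (inferred from `CodeInfo.from_str`, so treat as an
# assumption): one "field: value" code description per line, e.g.
#   topic_str = "code: 11, layer: 2, head: 3\ncode: 7, layer: 2, head: 3"
#   codes = parse_topic_codes_string(topic_str, cb_at="attn")
# With code_append=False (the default), codes that share a layer/head get decreasing
# code_pos values (-1, -2, ...) so they occupy different codebook slots.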


def find_similar_codes(cb_model, code_info, n=8):
    """Find the `n` most similar codes to the given code using cosine similarity.

    Useful for finding related codes for interpretability.
    """
    codebook = cb_model.get_codebook(code_info)
    device = codebook.weight.device
    code = codebook(torch.tensor(code_info.code).to(device))
    code = code.to(device)
    logits = torch.matmul(code, codebook.weight.T)
    _, indices = torch.topk(logits, n)
    assert indices[0] == code_info.code
    assert torch.allclose(logits[indices[0]], torch.tensor(1.0))
    return indices[1:], logits[indices[1:]].tolist()
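
# Usage sketch (assumes codebook vectors are unit-normalized, which the second assert
# checks, and that `cb_model.get_codebook` accepts a CodeInfo; values are illustrative):
#   info = CodeInfo(code=123, layer=2, head=None, cb_at="mlp")
#   similar_ids, sims = find_similar_codes(cb_model, info, n=8)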