import argparse

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import networkx as nx
import pandas as pd
import plotly.express
import plotly.graph_objects as go
import streamlit as st
import streamlit_extras.row as st_row
import torch
from jaxtyping import Float
from networkx.classes.digraph import DiGraph
from torch.amp import autocast
from transformers import HfArgumentParser

import llm_transparency_tool.components
import llm_transparency_tool.routes.contributions as contributions
import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm
from llm_transparency_tool.routes.graph_node import NodeType
from llm_transparency_tool.server.graph_selection import (
    GraphSelection,
    UiGraphEdge,
    UiGraphNode,
)
from llm_transparency_tool.server.monitor import SystemMonitor
from llm_transparency_tool.server.styles import (
    RenderSettings,
    logits_color_map,
    margins_css,
    string_to_display,
)
from llm_transparency_tool.server.utils import (
    B0,
    get_contribution_graph,
    load_dataset,
    load_model,
    possible_devices,
    run_model_with_session_caching,
    st_placeholder,
)


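# Streamlit's cache_resource would normally hash every argument to build the
# cache key, but graphs and models are expensive (or impossible) to hash.
# Mapping their types to `id` in hash_funcs below makes the cache key the
# object's identity instead, so a cached entry is reused only while the exact
# same object is passed in, and recomputed as soon as a new one is created.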
@st.cache_resource(
    hash_funcs={
        nx.Graph: id,
        DiGraph: id,
    }
)
def cached_build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
):
    """Cached wrapper around routes.graph.build_paths_to_predictions, keyed on
    the identity of `graph` (see the note on hash_funcs above)."""
    return llm_transparency_tool.routes.graph.build_paths_to_predictions(
        graph, n_layers, n_tokens, starting_tokens, threshold
    )


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id,
    }
)
def cached_run_inference_and_populate_state(
    stateless_model,
    sentences,
):
    """Copy the model and run it on `sentences`, caching the resulting stateful
    model so that UI re-renders of the same input skip the forward pass."""
    stateful_model = stateless_model.copy()
    stateful_model.run(sentences)
    return stateful_model


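# The viewer is configured from a JSON file parsed by HfArgumentParser (see the
# __main__ block at the bottom of this file). A minimal sketch of such a file,
# with hypothetical filenames and a null model path meaning "download from
# HuggingFace" (per the help text below), might look like:
#
#   {
#       "debug": false,
#       "preloaded_dataset_filename": "data/sentences.txt",
#       "models": {"gpt2": null},
#       "default_model": "gpt2"
#   }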
@dataclass
class LlmViewerConfig:
    debug: bool = field(
        default=False,
        metadata={"help": "Show debugging information, like the time profile."},
    )

    preloaded_dataset_filename: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the text file to load the lines from."},
    )

    demo_mode: bool = field(
        default=False,
        metadata={"help": "Whether the app should be in the demo mode."},
    )

    allow_loading_dataset_files: bool = field(
        default=True,
        metadata={
            "help": "Whether the app should be able to load the dataset files "
            "on the server side."
        },
    )

    max_user_string_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit for the length of user-provided sentences (in characters), "
            "or None if there is no limit."
        },
    )

    models: Dict[str, str] = field(
        default_factory=dict,
        metadata={
            "help": "Locations of models which are stored locally. Dictionary: official "
            "HuggingFace name -> path to dir. If None is specified, the model will be "
            "downloaded from HuggingFace."
        },
    )

    default_model: str = field(
        default="",
        metadata={"help": "The model to load once the UI is started."},
    )


class App:
    _stateful_model: Optional[TransparentLlm] = None
    render_settings = RenderSettings()
    _graph: Optional[nx.Graph] = None
    _contribution_threshold: float = 0.0
    _renormalize_after_threshold: bool = False
    _normalize_before_unembedding: bool = True

    @property
    def stateful_model(self) -> TransparentLlm:
        return self._stateful_model

    def __init__(self, config: LlmViewerConfig):
        self._config = config
        st.set_page_config(layout="wide")
        st.markdown(margins_css, unsafe_allow_html=True)

    def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
        if node is None:
            return None
        fn = {
            NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
            NodeType.AFTER_FFN: self.stateful_model.residual_out,
            NodeType.ORIGINAL: self.stateful_model.residual_in,
        }.get(node.type)
        if fn is None:
            # FFN nodes are not points on the residual stream, so there is no
            # residual vector to read out for them.
            return None
        return fn(node.layer)[B0][node.token]

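    # Note on indexing: the model's residual_* accessors appear to return
    # tensors of shape [batch, token, d_model]; since the app runs a single
    # sentence at a time, B0 (batch index 0) selects that sentence everywhere.
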
    def draw_model_info(self):
        info = self.stateful_model.model_info().__dict__
        df = pd.DataFrame(
            data=[str(x) for x in info.values()],
            index=info.keys(),
            columns=["Model parameter"],
        )
        st.dataframe(df, use_container_width=False)

    def draw_dataset_selection(self) -> Optional[str]:
        def update_dataset(filename: Optional[str]):
            dataset = load_dataset(filename) if filename is not None else []
            st.session_state["dataset"] = dataset
            st.session_state["dataset_file"] = filename

        if "dataset" not in st.session_state:
            update_dataset(self._config.preloaded_dataset_filename)

        if not self._config.demo_mode:
            if self._config.allow_loading_dataset_files:
                row_f = st_row.row([2, 1], vertical_align="bottom")
                filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "")
                if row_f.button("Load"):
                    update_dataset(filename)
            row_s = st_row.row([2, 1], vertical_align="bottom")
            new_sentence = row_s.text_input("New sentence")

            if row_s.button("Add"):
                max_len = self._config.max_user_string_length
                n = len(new_sentence)
                if max_len is None or n <= max_len:
                    st.session_state.dataset.append(new_sentence)
                    st.session_state.sentence_selector = new_sentence
                else:
                    st.warning(f"Sentence length {n} is larger than the configured limit of {max_len}")

        sentences = st.session_state.dataset
        selection = st.selectbox(
            "Sentence",
            sentences,
            index=len(sentences) - 1,
            key="sentence_selector",
        )
        return selection

    def _unembed(
        self,
        representation: torch.Tensor,
    ) -> torch.Tensor:
        return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)

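    # `normalize=True` presumably applies the model's final normalization layer
    # before projecting onto the vocabulary, which is the usual logit-lens
    # convention; the "Normalize before unembedding" sidebar checkbox toggles it.
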
    def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
        tokens = self.stateful_model.tokens()[B0]
        n_tokens = tokens.shape[0]
        model_info = self.stateful_model.model_info()

        graphs = cached_build_paths_to_predictions(
            self._graph,
            model_info.n_layers,
            n_tokens,
            list(range(n_tokens)),
            contribution_threshold,
        )

        return llm_transparency_tool.components.contribution_graph(
            model_info,
            self.stateful_model.tokens_to_strings(tokens),
            graphs,
            key=f"graph_{hash(self.sentence)}",
        )

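    # The contribution threshold is applied in exactly one of two places: if
    # the user asks to renormalize contributions after thresholding, it is
    # baked into the graph when get_contribution_graph builds it (see
    # run_inference); otherwise it is passed to draw_graph above and applied
    # only when the paths are traced for display.
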
    def draw_token_matrix(
        self,
        values: Float[torch.Tensor, "t t"],
        tokens: List[str],
        value_name: str,
        title: str,
    ):
        assert values.shape[0] == len(tokens)
        labels = {
            "x": "<b>src</b>",
            "y": "<b>tgt</b>",
            "color": value_name,
        }

        captions = [f"({i}){t}" for i, t in enumerate(tokens)]

        fig = plotly.express.imshow(
            values.cpu(),
            title=f"<b>{title}</b>",
            labels=labels,
            x=captions,
            y=captions,
            color_continuous_scale=self.render_settings.attention_color_map,
            aspect="equal",
        )
        fig.update_layout(
            autosize=True,
            margin=go.layout.Margin(
                l=50,
                r=0,
                b=100,
                t=100,
            ),
        )
        fig.update_xaxes(tickmode="linear")
        fig.update_yaxes(tickmode="linear")
        fig.update_coloraxes(showscale=False)

        st.plotly_chart(fig, use_container_width=True, theme=None)

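    # Layout assumption for the contribution tensor used below:
    # get_attention_contributions seems to return a tensor indexed as
    # [batch, target_token, source_token, head], which is how draw_attn_info
    # slices it to get per-head contributions for the selected edge (and it
    # matches the tgt/src axis labels in draw_token_matrix).
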
    def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
        """
        Returns: the index of the selected head.
        """
        n_heads = self.stateful_model.model_info().n_heads

        layer = edge.target.layer

        head_contrib, _ = contributions.get_attention_contributions(
            resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
            resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
            decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
        )

        # Contribution of each head for the selected (source -> target) edge.
        flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
        assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"

        selected_head = llm_transparency_tool.components.selector(
            items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
            indices=range(-1, n_heads),
            temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
            preselected_index=flat_contrib.argmax().item(),
            key=f"head_selector_layer_{layer}",
        )
        # "All" (or no selection) falls back to the most contributing head.
        if selected_head == -1 or selected_head is None:
            selected_head = flat_contrib.argmax().item()

        if selected_head is not None:
            tokens = [
                string_to_display(s)
                for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
            ]

            with container_attention_map:
                attn_container, contrib_container = st.columns([1, 1])
                with attn_container:
                    attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
                    self.draw_token_matrix(
                        attn,
                        tokens,
                        "attention",
                        f"Attention map L{layer} H{selected_head}",
                    )
                with contrib_container:
                    contrib = head_contrib[B0, :, :, selected_head]
                    self.draw_token_matrix(
                        contrib,
                        tokens,
                        "contribution",
                        f"Contribution map L{layer} H{selected_head}",
                    )

        return selected_head

    def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
        """
        Returns: the index of the selected neuron.
        """
        resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
        resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
        decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
        c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)

        # Keep only the most contributing neurons for the selector.
        top_values, top_i = c_ffn.sort(descending=True)
        n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
        top_neurons = top_i[0:n].tolist()

        selected_neuron = llm_transparency_tool.components.selector(
            items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
            indices=range(-1, n),
            temperatures=[0.0] + top_values[0:n].tolist(),
            preselected_index=-1,
            key="neuron_selector",
        )
        if selected_neuron is None:
            selected_neuron = -1
        # -1 ("All") means no individual neuron; otherwise map the selector
        # index back to the actual neuron index.
        selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]

        return selected_neuron

    def _draw_token_table(
        self,
        n_top: int,
        n_bottom: int,
        representation: torch.Tensor,
        predecessor: Optional[torch.Tensor] = None,
    ):
        n_total = n_top + n_bottom

        logits = self._unembed(representation)
        n_vocab = logits.shape[0]
        scores, indices = torch.topk(logits, n_top, largest=True)
        positions = list(range(n_top))

        if n_bottom > 0:
            low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
            indices = torch.cat((indices, low_indices.flip(0)))
            scores = torch.cat((scores, low_scores.flip(0)))
            positions += range(n_vocab - n_bottom, n_vocab)

        tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]

        if predecessor is not None:
            # Compare each token's rank here with its rank under the
            # predecessor representation to show how far it moved.
            pre_logits = self._unembed(predecessor)
            _, sorted_pre_indices = pre_logits.sort(descending=True)
            pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
            old_positions = [pre_indices_dict[i] for i in indices.tolist()]

            def pos_gain_string(pos, old_pos):
                if pos == old_pos:
                    return ""
                sign = "↓" if pos > old_pos else "↑"
                return f"({sign}{abs(pos - old_pos)})"

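            # For example (hypothetical ranks): pos_gain_string(3, 10) returns
            # "(↑7)", meaning the token rose 7 positions relative to the
            # predecessor, while pos_gain_string(10, 3) returns "(↓7)".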
            position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
        else:
            position_strings = [str(pos) for pos in positions]

        def pos_gain_color(s):
            color = "black"
            if isinstance(s, str):
                if "↓" in s:
                    color = "red"
                if "↑" in s:
                    color = "green"
            return f"color: {color}"

        top_df = pd.DataFrame(
            data=zip(position_strings, tokens, scores.tolist()),
            columns=["Pos", "Token", "Score"],
        )

        st.dataframe(
            top_df.style.map(pos_gain_color)
            .background_gradient(
                axis=0,
                cmap=logits_color_map(positive_and_negative=n_bottom > 0),
            )
            .format(precision=3),
            hide_index=True,
            height=self.render_settings.table_cell_height * (n_total + 1),
            use_container_width=True,
        )

    def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
        st.caption(block_name)
        self._draw_token_table(
            self.render_settings.n_promoted_tokens,
            self.render_settings.n_suppressed_tokens,
            representation,
            None,
        )

    def draw_top_tokens(
        self,
        node: UiGraphNode,
        container_top_tokens,
        container_token_dynamics,
    ) -> None:
        pre_node = node.get_residual_predecessor()
        if pre_node is None:
            return

        representation = self._get_representation(node)
        predecessor = self._get_representation(pre_node)

        with container_top_tokens:
            st.caption(node.get_name())
            self._draw_token_table(
                self.render_settings.n_top_tokens,
                0,
                representation,
                predecessor,
            )
        if container_token_dynamics is not None:
            with container_token_dynamics:
                self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())

    def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
        block_name = node.get_head_name(head)
        block_output = (
            self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
            if head is not None
            else self.stateful_model.attention_output(B0, node.layer, node.token)
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
        block_name = node.get_neuron_name(neuron)
        block_output = (
            self.stateful_model.neuron_output(node.layer, neuron)
            if neuron is not None
            else self.stateful_model.ffn_out(node.layer)[B0][node.token]
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
        """
        Draw the precision switch and AMP control.

        Returns: the selected dtype and whether AMP should be enabled.
        """
        if device == "cpu":
            dtype = torch.float32
        else:
            dtype = st.selectbox(
                "Precision",
                [torch.float16, torch.bfloat16, torch.float32],
                index=0,
            )

        # Mixed precision only makes sense for the half-precision dtypes.
        amp_enabled = dtype != torch.float32

        return dtype, amp_enabled

    def draw_controls(self):
        with st.sidebar.expander("Model", expanded=True):
            list_of_devices = possible_devices()
            if len(list_of_devices) > 1:
                self.device = st.selectbox(
                    "Device",
                    list_of_devices,
                    index=0,
                )
            else:
                self.device = list_of_devices[0]

            self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)

            model_list = list(self._config.models)
            # Fall back to the first model if the configured default is absent.
            default_choice = (
                model_list.index(self._config.default_model)
                if self._config.default_model in model_list
                else 0
            )

            self.model_name = st.selectbox(
                "Model",
                model_list,
                index=default_choice,
            )

            if self.model_name:
                self._stateful_model = load_model(
                    model_name=self.model_name,
                    _model_path=self._config.models[self.model_name],
                    _device=self.device,
                    _dtype=self.dtype,
                )
                self.model_key = self.model_name
                self.draw_model_info()

        self.sentence = self.draw_dataset_selection()

        with st.sidebar.expander("Graph", expanded=True):
            self._contribution_threshold = st.slider(
                min_value=0.01,
                max_value=0.1,
                step=0.01,
                value=0.04,
                format=r"%.3f",
                label="Contribution threshold",
            )
            self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
            self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)

    def run_inference(self):
        # autocast is a no-op when amp_enabled is False (the fp32 case).
        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])

        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._graph = get_contribution_graph(
                self.stateful_model,
                self.model_key,
                self.stateful_model.tokens()[B0].tolist(),
                (self._contribution_threshold if self._renormalize_after_threshold else 0.0),
            )

    def draw_graph_and_selection(
        self,
    ) -> None:
        (
            container_graph,
            container_tokens,
        ) = st.columns(self.render_settings.column_proportions)

        container_graph_left, container_graph_right = container_graph.columns([5, 1])

        container_graph_left.write("##### Graph")
        heads_placeholder = container_graph_right.empty()
        heads_placeholder.write("##### Blocks")
        container_graph_right_used = False

        container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
        container_top_tokens.write("##### Top Tokens")
        container_top_tokens_used = False
        container_token_dynamics.write("##### Promoted Tokens")
        container_token_dynamics_used = False

        try:
            if self.sentence is None:
                return

            with container_graph_left:
                # If contributions were already renormalized after thresholding
                # when the graph was built, don't threshold again here.
                selection = self.draw_graph(
                    self._contribution_threshold if not self._renormalize_after_threshold else 0.0
                )

            if selection is None:
                return

            node = selection.node
            edge = selection.edge

            if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write("##### Heads")
                    head = self.draw_attn_info(edge, container_graph)
                    with container_token_dynamics:
                        self.draw_attention_dynamics(edge.target, head)
                        container_token_dynamics_used = True
            elif node is not None and node.type == NodeType.FFN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write("##### Neurons")
                    neuron = self.draw_ffn_info(node)
                    with container_token_dynamics:
                        self.draw_ffn_dynamics(node, neuron)
                        container_token_dynamics_used = True

            if node is not None and node.is_in_residual_stream():
                self.draw_top_tokens(
                    node,
                    container_top_tokens,
                    container_token_dynamics if not container_token_dynamics_used else None,
                )
                container_top_tokens_used = True
                container_token_dynamics_used = True
        finally:
            # Fill any panel that was not drawn with a placeholder hint.
            if not container_graph_right_used:
                st_placeholder(
                    "Click on an edge to see head contributions. \n\n"
                    "Or click on FFN to see individual neuron contributions.",
                    container_graph_right,
                    height=1100,
                )
            if not container_top_tokens_used:
                st_placeholder("Select a node from residual stream to see its top tokens.", container_top_tokens, height=1100)
            if not container_token_dynamics_used:
                st_placeholder("Select a node to see its promoted tokens.", container_token_dynamics, height=1100)

    def run(self):
        with st.sidebar.expander("About", expanded=True):
            if self._config.demo_mode:
                st.caption("""
                    The app is deployed in demo mode, so only predefined models and inputs are available.\n
                    You can still install the app locally and use your own models and inputs.\n
                    See https://github.com/facebookresearch/llm-transparency-tool for more information.
                """)

        self.draw_controls()

        if not self.model_name:
            st.warning("No model selected")
            st.stop()

        if self.sentence is None:
            st.warning("No sentence selected")
        else:
            with torch.inference_mode():
                self.run_inference()

            self.draw_graph_and_selection()


if __name__ == "__main__":
    top_parser = argparse.ArgumentParser()
    top_parser.add_argument("config_file")
    args = top_parser.parse_args()

    parser = HfArgumentParser([LlmViewerConfig])
    config = parser.parse_json_file(args.config_file)[0]

    with SystemMonitor(config.debug):
        app = App(config)
        app.run()
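
# Typical launch (a sketch; the exact paths are illustrative and depend on
# where this script and your config live):
#
#   streamlit run llm_transparency_tool/server/app.py -- config/local.json
#
# Everything after `--` is forwarded to this script's own argparse, so
# `config_file` receives the JSON path parsed by HfArgumentParser above.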