import uuid
from typing import List, Optional, Tuple

import networkx as nx
import streamlit as st
import torch
import transformers

import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm

GPU = "gpu"
CPU = "cpu"

# Index of the only batch element: the tool runs inference on one sentence at a time.
B0 = 0


def possible_devices() -> List[str]:
    devices = []
    if torch.cuda.is_available():
        devices.append(GPU)
    devices.append(CPU)
    return devices


def load_dataset(filename: str) -> List[str]:
    with open(filename) as f:
        dataset = [s.strip("\n") for s in f.readlines()]
    print(f"Loaded {len(dataset)} sentences from {filename}")
    return dataset


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def load_model(
    model_name: str,
    _device: str,
    _model_path: Optional[str] = None,
    _dtype: torch.dtype = torch.float32,
) -> Tuple[TransparentLlm, str]:
    """
    Returns the loaded model along with its key. The key is just a unique string
    which can be used later to identify if the model has changed. Since the result
    is cached, the key stays stable until a different model is loaded.
    """
    assert _device in possible_devices()

    if _model_path is not None:
        # When a local path is given, load the weights and tokenizer with
        # Hugging Face transformers and hand them over to TransformerLens.
        causal_lm = transformers.AutoModelForCausalLM.from_pretrained(_model_path)
        tokenizer = transformers.AutoTokenizer.from_pretrained(_model_path)
    else:
        causal_lm = None
        tokenizer = None

    tl_lm = TransformerLensTransparentLlm(
        model_name=model_name,
        hf_model=causal_lm,
        tokenizer=tokenizer,
        device=_device,
        dtype=_dtype,
    )

    return tl_lm, str(uuid.uuid4())
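
# A hypothetical call (the model name is a placeholder); the key identifies
# this particular cached load:
#
#   model, model_key = load_model("gpt2", _device=CPU)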


def run_model(model: TransparentLlm, sentence: str) -> None:
    print(f"Running inference for '{sentence}'")
    model.run([sentence])


def load_model_with_session_caching(
    **kwargs,
) -> Tuple[TransparentLlm, str]:
    return load_model(**kwargs)


def run_model_with_session_caching(
    _model: TransparentLlm,
    model_key: str,
    sentence: str,
):
    LAST_RUN_MODEL_KEY = "last_run_model_key"
    LAST_RUN_SENTENCE = "last_run_sentence"
    state = st.session_state

    if (
        state.get(LAST_RUN_MODEL_KEY, None) == model_key
        and state.get(LAST_RUN_SENTENCE, None) == sentence
    ):
        return

    run_model(_model, sentence)
    state[LAST_RUN_MODEL_KEY] = model_key
    state[LAST_RUN_SENTENCE] = sentence
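
# Sketch of how the two session-caching helpers combine on a Streamlit rerun
# (placeholder arguments; the calling app supplies the real values):
#
#   model, model_key = load_model_with_session_caching(model_name="gpt2", _device=CPU)
#   run_model_with_session_caching(model, model_key, sentence)
#
# When neither the model key nor the sentence has changed since the last run,
# the second call returns immediately instead of re-running inference.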


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id
    }
)
def get_contribution_graph(
    model: TransparentLlm,
    model_key: str,
    tokens: List[str],
    threshold: float,
) -> nx.Graph:
    """
    The `model_key` and `tokens` are used only for caching. The model itself
    is not hashed by value: `hash_funcs` above maps it to its `id`.
    """
    return llm_transparency_tool.routes.graph.build_full_graph(
        model,
        B0,
        threshold,
    )
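
# Example of consuming the cached graph (a sketch; the node and edge attributes
# are defined by `build_full_graph`, not in this module):
#
#   g = get_contribution_graph(model, model_key, tokens, threshold=0.04)
#   print(g.number_of_nodes(), g.number_of_edges())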


def st_placeholder(
    text: str,
    container=st,
    border: bool = True,
    height: Optional[int] = 500,
):
    empty = container.empty()
    empty.container(border=border, height=height).write(
        f'<small>{text}</small>', unsafe_allow_html=True
    )
    return empty
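

if __name__ == "__main__":
    # Minimal smoke test sketching how the helpers above fit together. The
    # model name, sentence, and threshold are illustrative placeholders, and
    # caching behaviour is only guaranteed inside a running Streamlit app.
    device = possible_devices()[0]
    model, model_key = load_model("gpt2", _device=device)
    sentence = "The quick brown fox jumps over the lazy dog"
    run_model(model, sentence)
    # `tokens` is used only as a cache key, so the raw token strings suffice.
    graph = get_contribution_graph(model, model_key, sentence.split(), 0.04)
    print(graph)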