Spaces:
Sleeping
Sleeping
Update src/subpages/hidden_states.py
Browse files- src/subpages/hidden_states.py +39 -89
src/subpages/hidden_states.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
|
3 |
-
"""
|
4 |
import numpy as np
|
5 |
import plotly.express as px
|
6 |
import plotly.graph_objects as go
|
@@ -9,79 +6,24 @@ import streamlit as st
|
|
9 |
from src.subpages.page import Context, Page
|
10 |
|
11 |
|
12 |
-
|
13 |
-
def
|
14 |
-
|
|
|
15 |
|
16 |
-
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
from sklearn.decomposition import TruncatedSVD
|
27 |
|
28 |
-
|
29 |
-
return svd.fit_transform(X)
|
30 |
-
|
31 |
-
|
32 |
-
@st.cache
|
33 |
-
def reduce_dim_pca(X, random_state=42):
|
34 |
-
"""Principal component analysis (PCA).
|
35 |
-
|
36 |
-
Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
|
37 |
-
|
38 |
-
Args:
|
39 |
-
X: Training data
|
40 |
-
random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
|
41 |
-
|
42 |
-
Returns:
|
43 |
-
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
44 |
-
"""
|
45 |
-
from sklearn.decomposition import PCA
|
46 |
-
|
47 |
-
return PCA(n_components=2, random_state=random_state).fit_transform(X)
|
48 |
-
|
49 |
-
|
50 |
-
@st.cache
|
51 |
-
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
|
52 |
-
"""Uniform Manifold Approximation and Projection
|
53 |
-
|
54 |
-
Finds a low dimensional embedding of the data that approximates an underlying manifold.
|
55 |
-
|
56 |
-
Args:
|
57 |
-
X: Training data
|
58 |
-
n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
|
59 |
-
min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
|
60 |
-
metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
|
61 |
-
|
62 |
-
Returns:
|
63 |
-
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
64 |
-
"""
|
65 |
-
from umap import UMAP
|
66 |
-
|
67 |
-
return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
|
68 |
-
|
69 |
-
|
70 |
-
class HiddenStatesPage(Page):
|
71 |
-
name = "Hidden States"
|
72 |
-
icon = "grid-3x3"
|
73 |
-
|
74 |
-
def _get_widget_defaults(self):
|
75 |
-
return {
|
76 |
-
"n_tokens": 1_000,
|
77 |
-
"svd_n_iter": 5,
|
78 |
-
"svd_random_state": 42,
|
79 |
-
"umap_n_neighbors": 15,
|
80 |
-
"umap_metric": "euclidean",
|
81 |
-
"umap_min_dist": 0.1,
|
82 |
-
}
|
83 |
-
|
84 |
-
def render(self, context: Context):
|
85 |
st.title("Embeddings")
|
86 |
|
87 |
with st.expander("💡", expanded=True):
|
@@ -90,7 +32,6 @@ class HiddenStatesPage(Page):
|
|
90 |
)
|
91 |
|
92 |
col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
|
93 |
-
df = context.df_tokens_merged.copy()
|
94 |
dim_algo = "SVD"
|
95 |
n_tokens = 100
|
96 |
|
@@ -100,7 +41,7 @@ class HiddenStatesPage(Page):
|
|
100 |
"#tokens",
|
101 |
key="n_tokens",
|
102 |
min_value=100,
|
103 |
-
max_value=len(df["tokens"].unique()),
|
104 |
step=100,
|
105 |
)
|
106 |
|
@@ -131,30 +72,30 @@ class HiddenStatesPage(Page):
|
|
131 |
pass
|
132 |
|
133 |
with col2:
|
134 |
-
sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
|
135 |
|
136 |
-
X = np.array(df["hidden_states"].tolist())
|
137 |
transformed_hidden_states = None
|
138 |
if dim_algo == "SVD":
|
139 |
-
transformed_hidden_states =
|
140 |
elif dim_algo == "PCA":
|
141 |
-
transformed_hidden_states =
|
142 |
elif dim_algo == "UMAP":
|
143 |
-
transformed_hidden_states =
|
144 |
X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
|
145 |
)
|
146 |
|
147 |
assert isinstance(transformed_hidden_states, np.ndarray)
|
148 |
-
df["x"] = transformed_hidden_states[:, 0]
|
149 |
-
df["y"] = transformed_hidden_states[:, 1]
|
150 |
-
df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
|
151 |
-
df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
|
152 |
-
df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
|
153 |
-
df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
|
154 |
-
df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
|
155 |
-
df["disagreements"] = df["labels"] != df["preds"]
|
156 |
-
|
157 |
-
subset = df[:n_tokens]
|
158 |
disagreements_trace = go.Scatter(
|
159 |
x=subset[subset["disagreements"]]["x"],
|
160 |
y=subset[subset["disagreements"]]["y"],
|
@@ -192,3 +133,12 @@ class HiddenStatesPage(Page):
|
|
192 |
)
|
193 |
fig.add_trace(disagreements_trace)
|
194 |
st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import plotly.express as px
|
3 |
import plotly.graph_objects as go
|
|
|
6 |
from src.subpages.page import Context, Page
|
7 |
|
8 |
|
9 |
+
class HiddenStatesVisualizer:
|
10 |
+
def __init__(self, context: Context):
|
11 |
+
self.context = context
|
12 |
+
self.df = context.df_tokens_merged.copy()
|
13 |
|
14 |
+
def _reduce_dim_svd(self, X, n_iter: int, random_state=42):
|
15 |
+
# Implement your SVD reduction here
|
16 |
+
pass
|
17 |
|
18 |
+
def _reduce_dim_pca(self, X, random_state=42):
|
19 |
+
# Implement your PCA reduction here
|
20 |
+
pass
|
|
|
21 |
|
22 |
+
def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
    """Uniform Manifold Approximation and Projection.

    Finds a low dimensional embedding of the data that approximates an
    underlying manifold. The previous placeholder returned ``None``, which
    made the caller's ``isinstance(..., np.ndarray)`` assertion fail for
    the "UMAP" option; this restores the call the module used before the
    refactor.

    Args:
        X: Training data.
        n_neighbors (int, optional): Size of the local neighborhood used
            for manifold approximation. Defaults to 5.
        min_dist (float, optional): Effective minimum distance between
            embedded points. Defaults to 0.1.
        metric (str, optional): Metric used to compute distances in the
            high dimensional space. Defaults to "euclidean".

    Returns:
        ndarray: Reduced version of X, as returned by
        ``UMAP.fit_transform``.
    """
    # Imported lazily, as the pre-refactor module did, so the page still
    # loads when the optional umap-learn dependency is only needed here.
    from umap import UMAP

    reducer = UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric)
    return reducer.fit_transform(X)
|
|
|
25 |
|
26 |
+
def visualize_hidden_states(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
st.title("Embeddings")
|
28 |
|
29 |
with st.expander("💡", expanded=True):
|
|
|
32 |
)
|
33 |
|
34 |
col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
|
|
|
35 |
dim_algo = "SVD"
|
36 |
n_tokens = 100
|
37 |
|
|
|
41 |
"#tokens",
|
42 |
key="n_tokens",
|
43 |
min_value=100,
|
44 |
+
max_value=len(self.df["tokens"].unique()),
|
45 |
step=100,
|
46 |
)
|
47 |
|
|
|
72 |
pass
|
73 |
|
74 |
with col2:
|
75 |
+
sents = self.df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
|
76 |
|
77 |
+
X = np.array(self.df["hidden_states"].tolist())
|
78 |
transformed_hidden_states = None
|
79 |
if dim_algo == "SVD":
|
80 |
+
transformed_hidden_states = self._reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
|
81 |
elif dim_algo == "PCA":
|
82 |
+
transformed_hidden_states = self._reduce_dim_pca(X)
|
83 |
elif dim_algo == "UMAP":
|
84 |
+
transformed_hidden_states = self._reduce_dim_umap(
|
85 |
X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
|
86 |
)
|
87 |
|
88 |
assert isinstance(transformed_hidden_states, np.ndarray)
|
89 |
+
self.df["x"] = transformed_hidden_states[:, 0]
|
90 |
+
self.df["y"] = transformed_hidden_states[:, 1]
|
91 |
+
self.df["sent0"] = self.df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
|
92 |
+
self.df["sent1"] = self.df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
|
93 |
+
self.df["sent2"] = self.df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
|
94 |
+
self.df["sent3"] = self.df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
|
95 |
+
self.df["sent4"] = self.df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
|
96 |
+
self.df["disagreements"] = self.df["labels"] != self.df["preds"]
|
97 |
+
|
98 |
+
subset = self.df[:n_tokens]
|
99 |
disagreements_trace = go.Scatter(
|
100 |
x=subset[subset["disagreements"]]["x"],
|
101 |
y=subset[subset["disagreements"]]["y"],
|
|
|
133 |
)
|
134 |
fig.add_trace(disagreements_trace)
|
135 |
st.plotly_chart(fig)
|
136 |
+
|
137 |
+
|
138 |
+
class HiddenStatesPage(Page):
    """Page that hands rendering over to ``HiddenStatesVisualizer``."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Build a visualizer for ``context`` and draw the page with it."""
        HiddenStatesVisualizer(context).visualize_hidden_states()
|