mouadenna committed
Commit abcf6b2
1 Parent(s): 965f72e

Create app.py

Files changed (1)
  1. app.py +279 -0
app.py ADDED
@@ -0,0 +1,279 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ from typing import Dict, List, Tuple
+ import plotly.express as px
+ from sklearn.decomposition import PCA
+ from sklearn.manifold import TSNE
+ import plotly.graph_objects as go
+
+ st.set_page_config(
+     page_title="Token & Embedding Visualizer",
+     layout="wide"
+ )
+
+ COLORS = {
+     'Special': '#FFB6C1',
+     'Subword': '#98FB98',
+     'Word': '#87CEFA',
+     'Punctuation': '#DDA0DD'
+ }
+
+ @st.cache_resource
+ def load_models_and_tokenizers() -> Tuple[Dict, Dict]:
+     """Load tokenizers and models with error handling"""
+     model_names = {
+         "BERT": "bert-base-uncased",
+         "RoBERTa": "roberta-base",
+         "DistilBERT": "distilbert-base-uncased",
+         "MPNet": "microsoft/mpnet-base",
+         "DeBERTa": "microsoft/deberta-base",
+     }
+
+     tokenizers = {}
+     models = {}
+
+     for name, model_name in model_names.items():
+         try:
+             tokenizers[name] = AutoTokenizer.from_pretrained(model_name)
+             models[name] = AutoModel.from_pretrained(model_name)
+             st.success(f"✓ Loaded {name}")
+         except Exception as e:
+             st.warning(f"× Failed to load {name}: {str(e)}")
+
+     return tokenizers, models
+
+ def classify_token(token: str) -> str:
+     # Check special tokens and punctuation before subword prefix markers,
+     # so tokens like '.' are not misclassified as subwords
+     if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<pad>', '[PAD]', '[MASK]', '<mask>']:
+         return 'Special'
+     elif token in [',', '.', '!', '?', ';', ':', '"', "'", '(', ')', '[', ']', '{', '}']:
+         return 'Punctuation'
+     elif token.startswith(('##', '▁', 'Ġ', '_')):
+         return 'Subword'
+     else:
+         return 'Word'
+
+ @torch.no_grad()
+ def get_embeddings(text: str, model, tokenizer) -> Tuple[torch.Tensor, List[str]]:
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+     outputs = model(**inputs)
+     embeddings = outputs.last_hidden_state[0]  # Get first batch
+     tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+     return embeddings, tokens
+
+ def visualize_embeddings(embeddings: torch.Tensor, tokens: List[str], method: str = 'PCA') -> go.Figure:
+     embed_array = embeddings.numpy()
+
+     if method == 'PCA':
+         reducer = PCA(n_components=3)
+         reduced_embeddings = reducer.fit_transform(embed_array)
+         variance_explained = reducer.explained_variance_ratio_
+         method_info = f"Total variance explained: {sum(variance_explained):.2%}"
+     else:  # t-SNE
+         reducer = TSNE(n_components=3, random_state=42, perplexity=min(30, len(tokens) - 1))
+         reduced_embeddings = reducer.fit_transform(embed_array)
+         method_info = "t-SNE embedding (perplexity: {})".format(reducer.perplexity)
+
+     df = pd.DataFrame({
+         'x': reduced_embeddings[:, 0],
+         'y': reduced_embeddings[:, 1],
+         'z': reduced_embeddings[:, 2],
+         'token': tokens,
+         'type': [classify_token(t) for t in tokens]
+     })
+
+     fig = go.Figure()
+
+     for token_type in df['type'].unique():
+         mask = df['type'] == token_type
+         fig.add_trace(go.Scatter3d(
+             x=df[mask]['x'],
+             y=df[mask]['y'],
+             z=df[mask]['z'],
+             mode='markers+text',
+             name=token_type,
+             text=df[mask]['token'],
+             hovertemplate="Token: %{text}<br>Type: " + token_type + "<extra></extra>",
+             marker=dict(
+                 size=8,
+                 color=COLORS[token_type],
+                 opacity=0.8
+             )
+         ))
+
+     fig.update_layout(
+         title=f"{method} Visualization of Token Embeddings<br><sup>{method_info}</sup>",
+         scene=dict(
+             xaxis_title=f"{method}_1",
+             yaxis_title=f"{method}_2",
+             zaxis_title=f"{method}_3"
+         ),
+         width=800,
+         height=800
+     )
+
+     return fig
+
+ def compute_token_similarities(embeddings: torch.Tensor, tokens: List[str]) -> pd.DataFrame:
+     normalized_embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
+     similarities = torch.mm(normalized_embeddings, normalized_embeddings.t())
+
+     sim_df = pd.DataFrame(similarities.numpy(), columns=tokens, index=tokens)
+     return sim_df
+
+ st.title("🔀 Token & Embedding Visualizer")
+
+ # Load models and tokenizers
+ tokenizers, models = load_models_and_tokenizers()
+
+ token_tab, embedding_tab, similarity_tab = st.tabs([
+     "Token Visualization",
+     "Embedding Visualization",
+     "Token Similarities"
+ ])
+
+ default_text = "Hello world! Let's analyze how neural networks process language. The transformer architecture revolutionized NLP."
+ text_input = st.text_area("Enter text to analyze:", value=default_text, height=100)
+
+ with token_tab:
+     st.markdown("""
+     Token colors represent:
+     - 🟦 Blue: Complete words
+     - 🟩 Green: Subwords
+     - 🟨 Pink: Special tokens
+     - 🟪 Purple: Punctuation
+     """)
+
+     selected_models = st.multiselect(
+         "Select models to compare tokens",
+         options=list(tokenizers.keys()),
+         default=["BERT", "RoBERTa"],
+         max_selections=4
+     )
+
+     if text_input and selected_models:
+         cols = st.columns(len(selected_models))
+
+         for idx, model_name in enumerate(selected_models):
+             with cols[idx]:
+                 st.subheader(model_name)
+                 tokenizer = tokenizers[model_name]
+
+                 tokens = tokenizer.tokenize(text_input)
+                 token_ids = tokenizer.encode(text_input)
+
+                 if len(tokens) != len(token_ids):
+                     tokens = tokenizer.convert_ids_to_tokens(token_ids)
+
+                 st.metric("Tokens", len(tokens))
+
+                 html_tokens = []
+                 for token in tokens:
+                     color = COLORS[classify_token(token)]
+                     token_text = token.replace('<', '&lt;').replace('>', '&gt;')
+                     html_tokens.append(
+                         f'<span style="background-color: {color}; padding: 2px 4px; '
+                         f'margin: 2px; border-radius: 3px; font-family: monospace;">'
+                         f'{token_text}</span>'
+                     )
+
+                 st.markdown(
+                     '<div style="background-color: white; padding: 10px; '
+                     'border-radius: 5px; border: 1px solid #ddd;">'
+                     f'{"".join(html_tokens)}</div>',
+                     unsafe_allow_html=True
+                 )
+
+ with embedding_tab:
+     st.markdown("""
+     This tab shows how tokens are embedded in the model's vector space.
+     - Compare different dimensionality reduction techniques
+     - Observe clustering of similar tokens
+     - Explore the relationship between different token types
+     """)
+
+     col1, col2 = st.columns([2, 1])
+
+     with col1:
+         selected_model = st.selectbox(
+             "Select model for embedding visualization",
+             options=list(models.keys())
+         )
+
+     with col2:
+         viz_method = st.radio(
+             "Select visualization method",
+             options=['PCA', 't-SNE'],
+             horizontal=True
+         )
+
+     if text_input and selected_model:
+         with st.spinner(f"Generating embeddings with {selected_model}..."):
+             embeddings, tokens = get_embeddings(
+                 text_input,
+                 models[selected_model],
+                 tokenizers[selected_model]
+             )
+
+             fig = visualize_embeddings(embeddings, tokens, viz_method)
+             st.plotly_chart(fig, use_container_width=True)
+
+             with st.expander("Embedding Statistics"):
+                 embed_stats = pd.DataFrame({
+                     'Token': tokens,
+                     'Type': [classify_token(t) for t in tokens],
+                     'Mean': embeddings.mean(dim=1).numpy(),
+                     'Std': embeddings.std(dim=1).numpy(),
+                     'Norm': torch.norm(embeddings, dim=1).numpy()
+                 })
+                 st.dataframe(embed_stats, use_container_width=True)
+
+ with similarity_tab:
+     st.markdown("""
+     Explore token similarities based on their embedding representations.
+     - Blue cells indicate higher cosine similarity, red cells lower
+     - Hover over cells to see exact similarity scores
+     """)
+
+     if text_input and selected_model:
+         with st.spinner("Computing token similarities..."):
+             sim_df = compute_token_similarities(embeddings, tokens)
+
+             fig = px.imshow(
+                 sim_df,
+                 labels=dict(color="Cosine Similarity"),
+                 color_continuous_scale="RdYlBu",
+                 aspect="auto"
+             )
+             fig.update_layout(
+                 title="Token Similarity Matrix",
+                 width=800,
+                 height=800
+             )
+             st.plotly_chart(fig, use_container_width=True)
+
+             st.subheader("Most Similar Token Pairs")
+             sim_matrix = sim_df.values
+             np.fill_diagonal(sim_matrix, 0)  # Exclude self-similarities
+             top_k = min(10, len(tokens))
+
+             pairs = []
+             for i in range(len(tokens)):
+                 for j in range(i + 1, len(tokens)):
+                     pairs.append((tokens[i], tokens[j], sim_matrix[i, j]))
+
+             top_pairs = sorted(pairs, key=lambda x: x[2], reverse=True)[:top_k]
+
+             for token1, token2, sim in top_pairs:
+                 st.write(f"'{token1}' - '{token2}': {sim:.3f}")
+
+ st.markdown("---")
+ st.markdown("""
+ 💡 **Tips:**
+ - Try comparing how different models tokenize and embed the same text
+ - Use PCA for global structure and t-SNE for local relationships
+ - Check the similarity matrix for interesting token relationships
+ - Experiment with different text types (technical, casual, mixed)
+ """)