s-a-malik committed on
Commit bc61ed1 · 1 Parent(s): bf84689

sentence level highlighting, remove acc probe for now

Files changed (3)
  1. __pycache__/app.cpython-311.pyc +0 -0
  2. app.py +108 -56
  3. debug.ipynb +171 -18
__pycache__/app.cpython-311.pyc ADDED
Binary file (11.7 kB).
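For orientation before the diffs: the new app.py below switches from per-token to per-sentence probing. Here is a minimal sketch of that loop, assuming a Hugging Face causal LM and sklearn-style probes exposing predict_proba; the names model, tokenizer, se_probe and se_layer_range mirror identifiers in the diff, while the helper probe_by_sentence and its exact control flow are illustrative rather than the committed code.

import torch

def probe_by_sentence(model, tokenizer, se_probe, se_layer_range, input_ids, max_new_tokens=100):
    """Generate one token at a time; score each completed sentence with the probe."""
    scored_sentences = []   # list of (sentence, signed uncertainty score)
    sentence_ids = []       # token ids of the sentence currently being built
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,
                do_sample=True,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            new_id = out.sequences[0, -1].item()
            input_ids = out.sequences          # feed the extended sequence back in
            if new_id == tokenizer.eos_token_id:
                break
            sentence_ids.append(new_id)
            sentence = tokenizer.decode(sentence_ids, skip_special_tokens=True)
            if sentence.endswith((".", "!", "?", ";")):
                # Stack the last token's embedding from every layer, keep the probe's
                # layer range, flatten, and map P(uncertain) from [0, 1] to [-1, 1].
                layers = torch.stack([h[0, -1, :].cpu() for h in out.hidden_states[-1]]).numpy()
                feats = layers[se_layer_range[0]:se_layer_range[1]].reshape(1, -1)
                score = se_probe.predict_proba(feats)[0][1] * 2 - 1
                scored_sentences.append((sentence, score))
                sentence_ids = []
    return scored_sentences

In the committed code the same logic streams partial results with yield so Gradio can render each sentence's highlight as soon as it completes.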
 
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import pickle as pkl
3
  from pathlib import Path
4
  from threading import Thread
5
- from typing import List, Tuple, Iterator, Optional
6
  from queue import Queue
7
 
8
  import spaces
@@ -10,7 +10,7 @@ import gradio as gr
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
- # TODO Sentence level highlighting instead (prediction after every word is not what it was trained on). Also solves token-level highlighting issues.
14
  # TODO log prob output scaling highlighting instead?
15
  # TODO make it look nicer
16
  # TODO better examples.
@@ -18,8 +18,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
18
  # TODO add options to switch between models, SLT/TBG, layers?
19
  # TODO full semantic entropy calculation
20
 
21
- MAX_MAX_NEW_TOKENS = 2048
22
- DEFAULT_MAX_NEW_TOKENS = 1024
23
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
24
 
25
  DESCRIPTION = """
@@ -31,11 +31,6 @@ DESCRIPTION = """
31
  <li><span style="background-color: #00FF00; color: black">Green</span> indicates more certain generations</li>
32
  <li><span style="background-color: #FF0000; color: black">Red</span> indicates more uncertain generations</li>
33
  </ul>
34
- <p>The demo compares the model's uncertainty with two different probes:</p>
35
- <ul>
36
- <li><b>Semantic Uncertainty Probe:</b> Predicts the semantic uncertainty of the model's generations.</li>
37
- <li><b>Accuracy Probe:</b> Predicts the accuracy of the model's generations.</li>
38
- </ul>
39
  <p>Please see our paper for more details. NOTE: This demo is a work in progress.</p>
40
  """
41
 
@@ -49,7 +44,7 @@ EXAMPLES = [
49
  if torch.cuda.is_available():
50
  model_id = "meta-llama/Llama-2-7b-chat-hf"
51
  # TODO load the full model not the 8bit one?
52
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
53
  tokenizer = AutoTokenizer.from_pretrained(model_id)
54
  tokenizer.use_default_system_prompt = False
55
 
@@ -62,6 +57,7 @@ if torch.cuda.is_available():
62
  se_layer_range = probe_data['sep_layer_range']
63
  acc_probe = probe_data['t_amodel']
64
  acc_layer_range = probe_data['ap_layer_range']
 
65
  else:
66
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
67
 
@@ -74,7 +70,7 @@ def generate(
74
  top_p: float = 0.9,
75
  top_k: int = 50,
76
  repetition_penalty: float = 1.2,
77
- ) -> Tuple[str, str]:
78
  conversation = []
79
  if system_prompt:
80
  conversation.append({"role": "system", "content": system_prompt})
@@ -86,10 +82,7 @@ def generate(
86
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
87
  input_ids = input_ids.to(model.device)
88
 
89
- #### Generate without threading
90
  generation_kwargs = dict(
91
- input_ids=input_ids,
92
- max_new_tokens=max_new_tokens,
93
  do_sample=True,
94
  top_p=top_p,
95
  top_k=top_k,
@@ -98,40 +91,91 @@ def generate(
98
  output_hidden_states=True,
99
  return_dict_in_generate=True,
100
  )
101
- with torch.no_grad():
102
- outputs = model.generate(**generation_kwargs)
103
- generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
104
- generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
105
- print(generated_text)
106
- # hidden states
107
- hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
108
-
109
- se_highlighted_text = ""
110
- acc_highlighted_text = ""
111
 
112
- # skip the first hidden state as it is the prompt
113
- for i in range(1, len(hidden)):
114
 
115
- # Semantic Uncertainty Probe
116
- token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
117
  se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
118
  se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
119
-
120
- # Accuracy Probe
121
  acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
122
- acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
123
-
124
- output_id = outputs.sequences[0, input_ids.shape[1]+i]
125
- output_word = tokenizer.decode(output_id)
126
- print(output_id, output_word, se_probe_pred, acc_probe_pred)
127
-
128
- se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
129
- acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
130
- se_highlighted_text += f" {se_new_highlighted_text}"
131
- acc_highlighted_text += f" {acc_new_highlighted_text}"
132
-
133
- return se_highlighted_text, acc_highlighted_text
134
-
135
 
136
 
137
  def highlight_text(text: str, uncertainty_score: float) -> str:
@@ -151,7 +195,8 @@ def highlight_text(text: str, uncertainty_score: float) -> str:
151
  html_color, text
152
  )
153
 
154
- with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility: hidden}") as demo:
 
155
  gr.HTML(DESCRIPTION)
156
 
157
  with gr.Row():
@@ -168,34 +213,41 @@ with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility
168
 
169
  with gr.Row():
170
  generate_btn = gr.Button("Generate")
 
171
  # Add spacing between probes
172
  gr.HTML("<br><br>")
173
 
174
- with gr.Row():
175
- with gr.Column():
176
  # make a box
177
- title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
178
- se_output = gr.HTML(label="Semantic Uncertainty Probe")
179
-
180
  # Add spacing between columns
181
- gr.HTML("<div style='width: 20px;'></div>")
182
 
183
- with gr.Column():
184
- title = gr.HTML("<h2>Accuracy Probe</h2>")
185
- acc_output = gr.HTML(label="Accuracy Probe")
186
 
187
  gr.Examples(
188
  examples=EXAMPLES,
189
  inputs=[message, system_prompt],
190
- outputs=[se_output, acc_output],
 
191
  fn=generate,
192
  )
193
 
194
- generate_btn.click(
195
  generate,
196
  inputs=[message, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
197
- outputs=[se_output, acc_output]
 
198
  )
 
199
 
200
 
201
  if __name__ == "__main__":
 
2
  import pickle as pkl
3
  from pathlib import Path
4
  from threading import Thread
5
+ from typing import List, Tuple, Iterator, Optional, Generator
6
  from queue import Queue
7
 
8
  import spaces
 
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
+ # TODO this is not as fast as it could be using generate function with 1 token at a time
14
  # TODO log prob output scaling highlighting instead?
15
  # TODO make it look nicer
16
  # TODO better examples.
 
18
  # TODO add options to switch between models, SLT/TBG, layers?
19
  # TODO full semantic entropy calculation
20
 
21
+ MAX_MAX_NEW_TOKENS = 1024
22
+ DEFAULT_MAX_NEW_TOKENS = 100
23
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
24
 
25
  DESCRIPTION = """
 
31
  <li><span style="background-color: #00FF00; color: black">Green</span> indicates more certain generations</li>
32
  <li><span style="background-color: #FF0000; color: black">Red</span> indicates more uncertain generations</li>
33
  </ul>
34
  <p>Please see our paper for more details. NOTE: This demo is a work in progress.</p>
35
  """
36
 
 
44
  if torch.cuda.is_available():
45
  model_id = "meta-llama/Llama-2-7b-chat-hf"
46
  # TODO load the full model not the 8bit one?
47
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
48
  tokenizer = AutoTokenizer.from_pretrained(model_id)
49
  tokenizer.use_default_system_prompt = False
50
 
 
57
  se_layer_range = probe_data['sep_layer_range']
58
  acc_probe = probe_data['t_amodel']
59
  acc_layer_range = probe_data['ap_layer_range']
60
+ print(f"Loaded probes with layer ranges: {se_layer_range}, {acc_layer_range}")
61
  else:
62
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
63
 
 
70
  top_p: float = 0.9,
71
  top_k: int = 50,
72
  repetition_penalty: float = 1.2,
73
+ ) -> Generator[Tuple[str, str], None, None]:
74
  conversation = []
75
  if system_prompt:
76
  conversation.append({"role": "system", "content": system_prompt})
 
82
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
83
  input_ids = input_ids.to(model.device)
84
 
 
85
  generation_kwargs = dict(
86
  do_sample=True,
87
  top_p=top_p,
88
  top_k=top_k,
 
91
  output_hidden_states=True,
92
  return_dict_in_generate=True,
93
  )
94
+ sentence_start_idx = input_ids.shape[1]
95
+ sentence_token_count = 0
96
+ finished = False
+ generated_text = ""  # initialise before the loop; the EOS branch below reads it on the first iteration
97
 
 
 
98
 
99
+ with torch.no_grad():
100
+ # highlight and return the prompt
101
+ outputs = model.generate(**generation_kwargs, input_ids=input_ids, max_new_tokens=1)
102
+ prompt_tokens = outputs.sequences[0, :input_ids.shape[1]]
103
+ prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True)
104
+ print(prompt_tokens, prompt_text)
105
+ # hidden states
106
+ hidden = outputs.hidden_states
107
+ # last token embeddings (note this is the same as the token before generation given this is the prompt)
108
+ token_embeddings = torch.stack([generated_token[0, -1, :].cpu() for generated_token in hidden[0]]).numpy()
109
  se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
110
  se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
111
  acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
112
+ acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1 # accuracy probe is inverted wrt uncertainty
113
+ se_new_highlighted_text = highlight_text(prompt_text, se_probe_pred)
114
+ acc_new_highlighted_text = highlight_text(prompt_text, acc_probe_pred)
115
+ se_highlighted_text = f"{se_new_highlighted_text}<br>"
116
+ acc_highlighted_text = f"{acc_new_highlighted_text}<br>"
117
+
118
+ while not finished:
119
+ outputs = model.generate(**generation_kwargs, input_ids=input_ids, max_new_tokens=1)
120
+ # this should only be the one extra token (equivalent to -1)
121
+ generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
122
+ print(f"generated_tokens {generated_tokens}" )
123
+ # add to the conversation
124
+ input_ids = torch.cat([input_ids, generated_tokens.unsqueeze(0)], dim=-1)
125
+ # stop at the end of a sequence
126
+ if generated_tokens[-1] == tokenizer.eos_token_id or input_ids.shape[1] > max_new_tokens:
127
+ print("Finished")
128
+ finished = True
129
+ if generated_text != "":
130
+ # do final prediction on the last generated text (one before the eos token)
131
+ print("Predicting probes")
132
+ hidden = outputs.hidden_states # hidden states = (num generated tokens, num layers, batch size, num tokens, hidden size)
133
+ # last token embeddings
134
+ token_embeddings = torch.stack([generated_token[0, -2, :].cpu() for generated_token in hidden[-1]]).numpy()
135
+
136
+ se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
137
+ se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
138
+
139
+ acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
140
+ acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1
141
+ print(f"se_probe_pred {se_probe_pred}, acc_probe_pred {acc_probe_pred}")
142
+
143
+ se_new_highlighted_text = highlight_text(generated_text, se_probe_pred)
144
+ acc_new_highlighted_text = highlight_text(generated_text, acc_probe_pred)
145
+ se_highlighted_text += f" {se_new_highlighted_text}"
146
+ acc_highlighted_text += f" {acc_new_highlighted_text}"
147
+ sentence_start_idx += sentence_token_count
148
+ sentence_token_count = 0
149
+
150
+ # decode the full generated text
151
+ generated_text = tokenizer.decode(outputs.sequences[0, sentence_start_idx:], skip_special_tokens=True)
152
+ print(f"generated_text: {generated_text}")
153
+ sentence_token_count += 1
154
+
155
+ # TODO this should be when a factoid is detected rather than just punctuation. Is the SLT token always basically a period for the probes?
156
+ if generated_text.endswith(('.', '!', '?', ';', '."', '!"', '?"')):
157
+ print("Predicting probes")
158
+ hidden = outputs.hidden_states # hidden states = (num generated tokens, num layers, batch size, num tokens, hidden size)
159
+ # last token embeddings
160
+ token_embeddings = torch.stack([generated_token[0, -1, :].cpu() for generated_token in hidden[-1]]).numpy()
161
+
162
+ se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
163
+ se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
164
+
165
+ acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
166
+ acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1
167
+ print(f"se_probe_pred {se_probe_pred}, acc_probe_pred {acc_probe_pred}")
168
+
169
+ se_new_highlighted_text = highlight_text(generated_text, se_probe_pred)
170
+ acc_new_highlighted_text = highlight_text(generated_text, acc_probe_pred)
171
+ se_highlighted_text += f" {se_new_highlighted_text}"
172
+ acc_highlighted_text += f" {acc_new_highlighted_text}"
173
+ sentence_start_idx += sentence_token_count
174
+ sentence_token_count = 0
175
+ generated_text = ""
176
+
177
+ # yield se_highlighted_text + generated_text, acc_highlighted_text + generated_text
178
+ yield se_highlighted_text + generated_text #, acc_highlighted_text + generated_text
179
 
180
 
181
  def highlight_text(text: str, uncertainty_score: float) -> str:
 
195
  html_color, text
196
  )
197
 
198
+
199
+ with gr.Blocks(title="Llama-2 7B Chat with Semantic Uncertainty Probes", css="footer {visibility: hidden}") as demo:
200
  gr.HTML(DESCRIPTION)
201
 
202
  with gr.Row():
 
213
 
214
  with gr.Row():
215
  generate_btn = gr.Button("Generate")
216
+ stop_btn = gr.Button("Stop")
217
  # Add spacing between probes
218
  gr.HTML("<br><br>")
219
 
220
+ # with gr.Row():
221
+ with gr.Column():
222
+ title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
223
+ se_output = gr.HTML(label="Semantic Uncertainty Probe")
224
+ # with gr.Column():
225
  # make a box
226
+ # title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
227
+ # se_output = gr.HTML(label="Semantic Uncertainty Probe")
228
+
229
  # Add spacing between columns
230
+ # gr.HTML("<div style='width: 20px;'></div>")
231
 
232
+ # with gr.Column():
233
+ # title = gr.HTML("<h2>Accuracy Probe</h2>")
234
+ # acc_output = gr.HTML(label="Accuracy Probe")
235
 
236
  gr.Examples(
237
  examples=EXAMPLES,
238
  inputs=[message, system_prompt],
239
+ # outputs=[se_output, acc_output],
240
+ outputs=[se_output],
241
  fn=generate,
242
  )
243
 
244
+ generate_event = generate_btn.click(
245
  generate,
246
  inputs=[message, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
247
+ # outputs=[se_output, acc_output]
248
+ outputs=[se_output]
249
  )
250
+ stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[generate_event])
251
 
252
 
253
  if __name__ == "__main__":
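As a quick check on the score-to-colour mapping that highlight_text applies to each sentence (a self-contained copy of the same arithmetic, added here only for illustration): positive scores, i.e. more uncertain, shade towards red, and negative scores shade towards green.

def score_to_hex(score: float) -> str:
    # Same arithmetic as highlight_text: +1 -> pure red, -1 -> pure green.
    if score > 0:
        return "#%02X%02X%02X" % (255, int(255 * (1 - score)), int(255 * (1 - score)))
    return "#%02X%02X%02X" % (int(255 * (1 + score)), 255, int(255 * (1 + score)))

print(score_to_hex(1.0))   # #FF0000  (maximally uncertain)
print(score_to_hex(0.5))   # #FF7F7F
print(score_to_hex(-1.0))  # #00FF00  (maximally certain)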
debug.ipynb CHANGED
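For context on what the probes consume in both app.py and this notebook: each probe is an sklearn-style classifier over a flattened slice of one token's per-layer hidden states. A minimal shape check follows; 33 layers (embeddings plus 32 blocks) and hidden size 4096 correspond to Llama-2-7B, while the layer range (13, 23) is purely hypothetical since the real ranges come from probe_data.

import numpy as np

# Stand-in for the stacked per-layer embeddings of one token.
token_embeddings = np.random.randn(33, 4096).astype(np.float32)

layer_range = (13, 23)  # illustrative; the committed code reads sep_layer_range / ap_layer_range
features = token_embeddings[layer_range[0]:layer_range[1]].reshape(1, -1)
print(features.shape)   # (1, 40960) -- the 2-D input predict_proba expects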
@@ -67,19 +67,21 @@
67
  "metadata": {},
68
  "outputs": [],
69
  "source": [
70
- "probe = probe_data['t_bmodel']\n",
71
- "layer_range = probe_data['sep_layer_range']"
72
  ]
73
  },
74
  {
75
  "cell_type": "code",
76
- "execution_count": 5,
77
  "metadata": {},
78
  "outputs": [
79
  {
80
  "data": {
81
  "application/vnd.jupyter.widget-view+json": {
82
- "model_id": "1c0e30b73cab48069e985203c598a9b0",
83
  "version_major": 2,
84
  "version_minor": 0
85
  },
@@ -89,6 +91,13 @@
89
  },
90
  "metadata": {},
91
  "output_type": "display_data"
92
  }
93
  ],
94
  "source": [
@@ -96,14 +105,145 @@
96
  "from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n",
97
  "\n",
98
  "model_id = \"meta-llama/Llama-2-7b-chat-hf\"\n",
99
- "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"cpu\")\n",
100
  "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
101
  "tokenizer.use_default_system_prompt = False"
102
  ]
103
  },
104
  {
105
  "cell_type": "code",
106
- "execution_count": 8,
107
  "metadata": {},
108
  "outputs": [
109
  {
@@ -113,14 +253,23 @@
113
  "tensor([[ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492,\n",
114
  " 526, 263, 8444, 20255, 29889, 13, 29966, 829, 14816, 29903,\n",
115
  " 6778, 13, 13, 5816, 338, 278, 7483, 310, 3444, 29973,\n",
116
- " 518, 29914, 25580, 29962]]) torch.Size([1, 34])\n"
117
  ]
118
  },
119
  {
120
- "name": "stderr",
121
- "output_type": "stream",
122
- "text": [
123
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
124
  ]
125
  }
126
  ],
@@ -161,17 +310,22 @@
161
  "\n",
162
  "generated_text = \"\"\n",
163
  "highlighted_text = \"\"\n",
164
  "\n",
 
165
  "for new_text in streamer:\n",
166
  " print(new_text)\n",
167
  " generated_text += new_text\n",
 
168
  " current_input_ids = tokenizer.encode(generated_text, return_tensors=\"pt\").to(model.device)\n",
169
  " print(current_input_ids, current_input_ids.shape)\n",
170
  " with torch.no_grad():\n",
171
  " outputs = model(current_input_ids, output_hidden_states=True)\n",
172
- " print(outputs)\n",
173
  " hidden = outputs.hidden_states \n",
174
- " print(hidden.shape)\n",
 
175
  " # Stack second last token embeddings from all layers \n",
176
  " # if len(hidden) == 1: # FIX: runtime error for mistral-7b on bioasq\n",
177
  " # sec_last_input = hidden[0]\n",
@@ -179,9 +333,9 @@
179
  " # sec_last_input = hidden[-2]\n",
180
  " # else:\n",
181
  " # sec_last_input = hidden[n_generated - 2]\n",
182
- " # sec_last_token_embedding = torch.stack([layer[:, -1, :].cpu() for layer in sec_last_input])\n",
183
- " # print(sec_last_token_embedding.shape)\n",
184
- " last_hidden_state = outputs.hidden_states[-1][:, -1, :].cpu().numpy()\n",
185
  " print(last_hidden_state.shape) \n",
186
  " # TODO potentially need to only compute uncertainty for the last token in sentence?\n"
187
  ]
@@ -194,8 +348,7 @@
194
  "source": [
195
  "# concat hidden states\n",
196
  "\n",
197
- "\n",
198
- "hidden_states = np.concatenate(np.array(hidden_states)[layer_range], axis=1)\n",
199
  "# predict with probe\n",
200
  "pred = probe.predict(hidden_states)\n",
201
  "print(pred)"
 
67
  "metadata": {},
68
  "outputs": [],
69
  "source": [
70
+ "se_probe = probe_data['t_bmodel']\n",
71
+ "se_layer_range = probe_data['sep_layer_range']\n",
72
+ "acc_probe = probe_data['t_amodel']\n",
73
+ "acc_layer_range = probe_data['ap_layer_range']"
74
  ]
75
  },
76
  {
77
  "cell_type": "code",
78
+ "execution_count": 3,
79
  "metadata": {},
80
  "outputs": [
81
  {
82
  "data": {
83
  "application/vnd.jupyter.widget-view+json": {
84
+ "model_id": "30a1c8e576f6448bb228b4ae9a3a8a48",
85
  "version_major": 2,
86
  "version_minor": 0
87
  },
 
91
  },
92
  "metadata": {},
93
  "output_type": "display_data"
94
+ },
95
+ {
96
+ "name": "stderr",
97
+ "output_type": "stream",
98
+ "text": [
99
+ "Some parameters are on the meta device device because they were offloaded to the disk.\n"
100
+ ]
101
  }
102
  ],
103
  "source": [
 
105
  "from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n",
106
  "\n",
107
  "model_id = \"meta-llama/Llama-2-7b-chat-hf\"\n",
108
+ "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"auto\")\n",
109
  "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
110
  "tokenizer.use_default_system_prompt = False"
111
  ]
112
  },
113
  {
114
  "cell_type": "code",
115
+ "execution_count": 6,
116
+ "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "name": "stdout",
120
+ "output_type": "stream",
121
+ "text": [
122
+ "Љ ( \"ass\n",
123
+ "ЪЏ\n",
124
+ "հ MO-OC\n",
125
+ "tensor(30488, device='mps:0') Љ 1.0 -0.014414779243550946\n",
126
+ "tensor(313, device='mps:0') ( -0.9998164331881116 0.9597905489862286\n",
127
+ "tensor(376, device='mps:0') \" 0.9999998197256226 -0.9792630307582237\n",
128
+ "tensor(465, device='mps:0') ass -0.9999994897301452 0.9680999957882863\n",
129
+ "tensor(13, device='mps:0') \n",
130
+ " -0.99999964561314 0.9983907264450047\n",
131
+ "tensor(31147, device='mps:0') Ъ 1.0 -0.9999976710226259\n",
132
+ "tensor(30282, device='mps:0') Џ 1.0 0.9999912572082477\n",
133
+ "tensor(13, device='mps:0') \n",
134
+ " 0.9999999999869607 0.9999964462206883\n",
135
+ "tensor(31488, device='mps:0') հ 1.0 -1.0\n",
136
+ "tensor(341, device='mps:0') M 0.9045896738793786 0.5590883316684834\n",
137
+ "tensor(29949, device='mps:0') O -0.9999999803476437 -0.5270551643185932\n",
138
+ "tensor(29899, device='mps:0') - 0.9992488974195408 0.9987826119127319\n",
139
+ "tensor(29949, device='mps:0') O -0.9713693636571169 0.9993573968241007\n",
140
+ "tensor(29907, device='mps:0') C -0.9999999701427968 0.9904799691607524\n",
141
+ " <span style=\"background-color: #FF0000; color: black\">Љ</span> <span style=\"background-color: #00FF00; color: black\">(</span> <span style=\"background-color: #FF0000; color: black\">\"</span> <span style=\"background-color: #00FF00; color: black\">ass</span> <span style=\"background-color: #00FF00; color: black\">\n",
142
+ "</span> <span style=\"background-color: #FF0000; color: black\">Ъ</span> <span style=\"background-color: #FF0000; color: black\">Џ</span> <span style=\"background-color: #FF0000; color: black\">\n",
143
+ "</span> <span style=\"background-color: #FF0000; color: black\">հ</span> <span style=\"background-color: #FF1818; color: black\">M</span> <span style=\"background-color: #00FF00; color: black\">O</span> <span style=\"background-color: #FF0000; color: black\">-</span> <span style=\"background-color: #07FF07; color: black\">O</span> <span style=\"background-color: #00FF00; color: black\">C</span>\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "from typing import Tuple\n",
149
+ "\n",
150
+ "MAX_INPUT_TOKEN_LENGTH = 512\n",
151
+ "\n",
152
+ "\n",
153
+ "def generate(\n",
154
+ " message: str,\n",
155
+ " system_prompt: str,\n",
156
+ " max_new_tokens: int = 10,\n",
157
+ " temperature: float = 0.6,\n",
158
+ " top_p: float = 0.9,\n",
159
+ " top_k: int = 50,\n",
160
+ " repetition_penalty: float = 1.2,\n",
161
+ ") -> Tuple[str, str]:\n",
162
+ " conversation = []\n",
163
+ " if system_prompt:\n",
164
+ " conversation.append({\"role\": \"system\", \"content\": system_prompt})\n",
165
+ " conversation.append({\"role\": \"user\", \"content\": message})\n",
166
+ "\n",
167
+ " input_ids = tokenizer.apply_chat_template(conversation, return_tensors=\"pt\")\n",
168
+ " if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:\n",
169
+ " input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]\n",
170
+ " input_ids = input_ids.to(model.device)\n",
171
+ "\n",
172
+ " #### Generate without threading\n",
173
+ " generation_kwargs = dict(\n",
174
+ " input_ids=input_ids,\n",
175
+ " max_new_tokens=max_new_tokens,\n",
176
+ " do_sample=True,\n",
177
+ " top_p=top_p,\n",
178
+ " top_k=top_k,\n",
179
+ " temperature=temperature,\n",
180
+ " repetition_penalty=repetition_penalty,\n",
181
+ " output_hidden_states=True,\n",
182
+ " return_dict_in_generate=True,\n",
183
+ " attention_mask=torch.ones_like(input_ids),\n",
184
+ " )\n",
185
+ " with torch.no_grad():\n",
186
+ " outputs = model.generate(**generation_kwargs)\n",
187
+ " generated_tokens = outputs.sequences[0, input_ids.shape[1]:]\n",
188
+ " generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n",
189
+ " print(generated_text)\n",
190
+ " # hidden states\n",
191
+ " hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)\n",
192
+ "\n",
193
+ " se_highlighted_text = \"\"\n",
194
+ " acc_highlighted_text = \"\"\n",
195
+ "\n",
196
+ " # skip the first hidden state as it is the prompt\n",
197
+ " for i in range(1, len(hidden)):\n",
198
+ "\n",
199
+ " # Semantic Uncertainty Probe\n",
200
+ " token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)\n",
201
+ " se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)\n",
202
+ " se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1\n",
203
+ " \n",
204
+ " # Accuracy Probe\n",
205
+ " acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)\n",
206
+ " acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1\n",
207
+ " \n",
208
+ " output_id = outputs.sequences[0, input_ids.shape[1]+i]\n",
209
+ " output_word = tokenizer.decode(output_id)\n",
210
+ " print(output_id, output_word, se_probe_pred, acc_probe_pred) \n",
211
+ "\n",
212
+ " se_new_highlighted_text = highlight_text(output_word, se_probe_pred)\n",
213
+ " acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)\n",
214
+ " se_highlighted_text += f\" {se_new_highlighted_text}\"\n",
215
+ " acc_highlighted_text += f\" {acc_new_highlighted_text}\"\n",
216
+ " \n",
217
+ " return se_highlighted_text, acc_highlighted_text\n",
218
+ "\n",
219
+ "\n",
220
+ "def highlight_text(text: str, uncertainty_score: float) -> str:\n",
221
+ " if uncertainty_score > 0:\n",
222
+ " html_color = \"#%02X%02X%02X\" % (\n",
223
+ " 255,\n",
224
+ " int(255 * (1 - uncertainty_score)),\n",
225
+ " int(255 * (1 - uncertainty_score)),\n",
226
+ " )\n",
227
+ " else:\n",
228
+ " html_color = \"#%02X%02X%02X\" % (\n",
229
+ " int(255 * (1 + uncertainty_score)),\n",
230
+ " 255,\n",
231
+ " int(255 * (1 + uncertainty_score)),\n",
232
+ " )\n",
233
+ " return '<span style=\"background-color: {}; color: black\">{}</span>'.format(\n",
234
+ " html_color, text\n",
235
+ " )\n",
236
+ "\n",
237
+ "message = \"What is the capital of France?\"\n",
238
+ "system_prompt = \"\"\n",
239
+ "se_highlighted_text, acc_highlighted_text = generate(message, system_prompt)\n",
240
+ "print(se_highlighted_text)\n",
241
+ " "
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 13,
247
  "metadata": {},
248
  "outputs": [
249
  {
 
253
  "tensor([[ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492,\n",
254
  " 526, 263, 8444, 20255, 29889, 13, 29966, 829, 14816, 29903,\n",
255
  " 6778, 13, 13, 5816, 338, 278, 7483, 310, 3444, 29973,\n",
256
+ " 518, 29914, 25580, 29962]]) torch.Size([1, 34])\n",
257
+ "\n",
258
+ " \n"
259
  ]
260
  },
261
  {
262
+ "ename": "KeyboardInterrupt",
263
+ "evalue": "",
264
+ "output_type": "error",
265
+ "traceback": [
266
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
267
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
268
+ "Cell \u001b[0;32mIn[13], line 37\u001b[0m\n\u001b[1;32m 35\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 36\u001b[0m highlighted_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m streamer:\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(output)\n\u001b[1;32m 39\u001b[0m generated_text \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m output\n",
269
+ "File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/site-packages/transformers/generation/streamers.py:223\u001b[0m, in \u001b[0;36mTextIteratorStreamer.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 223\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtext_queue\u001b[38;5;241m.\u001b[39mget(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtimeout)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstop_signal:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m()\n",
270
+ "File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/queue.py:180\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m remaining \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m:\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[0;32m--> 180\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnot_empty\u001b[38;5;241m.\u001b[39mwait(remaining)\n\u001b[1;32m 181\u001b[0m item \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get()\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnot_full\u001b[38;5;241m.\u001b[39mnotify()\n",
271
+ "File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/threading.py:324\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 324\u001b[0m gotit \u001b[38;5;241m=\u001b[39m waiter\u001b[38;5;241m.\u001b[39macquire(\u001b[38;5;28;01mTrue\u001b[39;00m, timeout)\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 326\u001b[0m gotit \u001b[38;5;241m=\u001b[39m waiter\u001b[38;5;241m.\u001b[39macquire(\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
272
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
273
  ]
274
  }
275
  ],
 
310
  "\n",
311
  "generated_text = \"\"\n",
312
  "highlighted_text = \"\"\n",
313
+ "for output in streamer:\n",
314
+ " print(output)\n",
315
+ " generated_text += output\n",
316
  "\n",
317
+ " # yield generated_text\n",
318
  "for new_text in streamer:\n",
319
  " print(new_text)\n",
320
  " generated_text += new_text\n",
321
+ " print(generated_text)\n",
322
  " current_input_ids = tokenizer.encode(generated_text, return_tensors=\"pt\").to(model.device)\n",
323
  " print(current_input_ids, current_input_ids.shape)\n",
324
  " with torch.no_grad():\n",
325
  " outputs = model(current_input_ids, output_hidden_states=True)\n",
 
326
  " hidden = outputs.hidden_states \n",
327
+ " print(len(hidden))\n",
328
+ " print(hidden[-1].shape)\n",
329
  " # Stack second last token embeddings from all layers \n",
330
  " # if len(hidden) == 1: # FIX: runtime error for mistral-7b on bioasq\n",
331
  " # sec_last_input = hidden[0]\n",
 
333
  " # sec_last_input = hidden[-2]\n",
334
  " # else:\n",
335
  " # sec_last_input = hidden[n_generated - 2]\n",
336
+ " sec_last_token_embedding = torch.stack([layer[:, -1, :].cpu() for layer in hidden])\n",
337
+ " print(sec_last_token_embedding.shape)\n",
338
+ " last_hidden_state = hidden[-1][:, -1, :].cpu().numpy()\n",
339
  " print(last_hidden_state.shape) \n",
340
  " # TODO potentially need to only compute uncertainty for the last token in sentence?\n"
341
  ]
 
348
  "source": [
349
  "# concat hidden states\n",
350
  "\n",
351
+ "sec_last_token_embedding = np.concatenate(sec_last_token_embedding.cpu().numpy()[layer_range], axis=1)\n",
 
352
  "# predict with probe\n",
353
  "pred = probe.predict(hidden_states)\n",
354
  "print(pred)"