Spaces: Running on Zero

s-a-malik committed
Commit bc61ed1 • 1 Parent(s): bf84689
sentence level highlighting, remove acc probe for now

Browse files:
- __pycache__/app.cpython-311.pyc +0 -0
- app.py +108 -56
- debug.ipynb +171 -18
__pycache__/app.cpython-311.pyc
ADDED
Binary file (11.7 kB).
app.py
CHANGED
@@ -2,7 +2,7 @@ import os
 import pickle as pkl
 from pathlib import Path
 from threading import Thread
-from typing import List, Tuple, Iterator, Optional
+from typing import List, Tuple, Iterator, Optional, Generator
 from queue import Queue
 
 import spaces
@@ -10,7 +10,7 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# TODO
+# TODO this is not as fast as it could be using generate function with 1 token at a time
 # TODO log prob output scaling highlighting instead?
 # TODO make it look nicer
 # TODO better examples.
@@ -18,8 +18,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # TODO add options to switch between models, SLT/TBG, layers?
 # TODO full semantic entropy calculation
 
-MAX_MAX_NEW_TOKENS =
-DEFAULT_MAX_NEW_TOKENS =
+MAX_MAX_NEW_TOKENS = 1024
+DEFAULT_MAX_NEW_TOKENS = 100
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """
@@ -31,11 +31,6 @@ DESCRIPTION = """
 <li><span style="background-color: #00FF00; color: black">Green</span> indicates more certain generations</li>
 <li><span style="background-color: #FF0000; color: black">Red</span> indicates more uncertain generations</li>
 </ul>
-<p>The demo compares the model's uncertainty with two different probes:</p>
-<ul>
-<li><b>Semantic Uncertainty Probe:</b> Predicts the semantic uncertainty of the model's generations.</li>
-<li><b>Accuracy Probe:</b> Predicts the accuracy of the model's generations.</li>
-</ul>
 <p>Please see our paper for more details. NOTE: This demo is a work in progress.</p>
 """
 
@@ -49,7 +44,7 @@ EXAMPLES = [
 if torch.cuda.is_available():
     model_id = "meta-llama/Llama-2-7b-chat-hf"
     # TODO load the full model not the 8bit one?
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
 
@@ -62,6 +57,7 @@ if torch.cuda.is_available():
     se_layer_range = probe_data['sep_layer_range']
     acc_probe = probe_data['t_amodel']
     acc_layer_range = probe_data['ap_layer_range']
+    print(f"Loaded probes with layer ranges: {se_layer_range}, {acc_layer_range}")
 else:
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
@@ -74,7 +70,7 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Tuple[str, str]:
+) -> Generator[Tuple[str, str], None, None]:
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -86,10 +82,7 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    #### Generate without threading
     generation_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
@@ -98,40 +91,91 @@
         output_hidden_states=True,
         return_dict_in_generate=True,
     )
-
-
-
-    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-    print(generated_text)
-    # hidden states
-    hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
-
-    se_highlighted_text = ""
-    acc_highlighted_text = ""
+    sentence_start_idx = input_ids.shape[1]
+    sentence_token_count = 0
+    finished = False
 
-    # skip the first hidden state as it is the prompt
-    for i in range(1, len(hidden)):
 
-
+    with torch.no_grad():
+        # highlight and return the prompt
+        outputs = model.generate(**generation_kwargs, input_ids=input_ids, max_new_tokens=1)
+        prompt_tokens = outputs.sequences[0, :input_ids.shape[1]]
+        prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True)
+        print(prompt_tokens, prompt_text)
+        # hidden states
+        hidden = outputs.hidden_states
+        # last token embeddings (note this is the same as the token before generation given this is the prompt)
+        token_embeddings = torch.stack([generated_token[0, -1, :].cpu() for generated_token in hidden[0]]).numpy()
         se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
         se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-        # Accuracy Probe
         acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
-        acc_probe_pred =
-
-
-
-
-
-
-
-
-
-
-
-
+        acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1 # accuracy probe is inverted wrt uncertainty
+        se_new_highlighted_text = highlight_text(prompt_text, se_probe_pred)
+        acc_new_highlighted_text = highlight_text(prompt_text, acc_probe_pred)
+        se_highlighted_text = f"{se_new_highlighted_text}<br>"
+        acc_highlighted_text = f"{acc_new_highlighted_text}<br>"
+
+        while not finished:
+            outputs = model.generate(**generation_kwargs, input_ids=input_ids, max_new_tokens=1)
+            # this should only be the one extra token (equivalent to -1)
+            generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
+            print(f"generated_tokens {generated_tokens}" )
+            # add to the conversation
+            input_ids = torch.cat([input_ids, generated_tokens.unsqueeze(0)], dim=-1)
+            # stop at the end of a sequence
+            if generated_tokens[-1] == tokenizer.eos_token_id or input_ids.shape[1] > max_new_tokens:
+                print("Finished")
+                finished = True
+                if generated_text != "":
+                    # do final prediction on the last generated text (one before the eos token)
+                    print("Predicting probes")
+                    hidden = outputs.hidden_states # hidden states = (num generated tokens, num layers, batch size, num tokens, hidden size)
+                    # last token embeddings
+                    token_embeddings = torch.stack([generated_token[0, -2, :].cpu() for generated_token in hidden[-1]]).numpy()
+
+                    se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+                    se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+                    acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+                    acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1
+                    print(f"se_probe_pred {se_probe_pred}, acc_probe_pred {acc_probe_pred}")
+
+                    se_new_highlighted_text = highlight_text(generated_text, se_probe_pred)
+                    acc_new_highlighted_text = highlight_text(generated_text, acc_probe_pred)
+                    se_highlighted_text += f" {se_new_highlighted_text}"
+                    acc_highlighted_text += f" {acc_new_highlighted_text}"
+                    sentence_start_idx += sentence_token_count
+                    sentence_token_count = 0
+
+            # decode the full generated text
+            generated_text = tokenizer.decode(outputs.sequences[0, sentence_start_idx:], skip_special_tokens=True)
+            print(f"generated_text: {generated_text}")
+            sentence_token_count += 1
+
+            # TODO this should be when a factoid is detected rather than just punctuation. Is the SLT token always basically a period for the probes?
+            if generated_text.endswith(('.', '!', '?', ';', '."', '!"', '?"')):
+                print("Predicting probes")
+                hidden = outputs.hidden_states # hidden states = (num generated tokens, num layers, batch size, num tokens, hidden size)
+                # last token embeddings
+                token_embeddings = torch.stack([generated_token[0, -1, :].cpu() for generated_token in hidden[-1]]).numpy()
+
+                se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+                se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+                acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+                acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][0] * 2 - 1
+                print(f"se_probe_pred {se_probe_pred}, acc_probe_pred {acc_probe_pred}")
+
+                se_new_highlighted_text = highlight_text(generated_text, se_probe_pred)
+                acc_new_highlighted_text = highlight_text(generated_text, acc_probe_pred)
+                se_highlighted_text += f" {se_new_highlighted_text}"
+                acc_highlighted_text += f" {acc_new_highlighted_text}"
+                sentence_start_idx += sentence_token_count
+                sentence_token_count = 0
+                generated_text = ""
+
+            # yield se_highlighted_text + generated_text, acc_highlighted_text + generated_text
+            yield se_highlighted_text + generated_text #, acc_highlighted_text + generated_text
 
 
 def highlight_text(text: str, uncertainty_score: float) -> str:
@@ -151,7 +195,8 @@ def highlight_text(text: str, uncertainty_score: float) -> str:
         html_color, text
     )
 
-with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility: hidden}") as demo:
+
+with gr.Blocks(title="Llama-2 7B Chat with Semantic Uncertainty Probes", css="footer {visibility: hidden}") as demo:
     gr.HTML(DESCRIPTION)
 
     with gr.Row():
@@ -168,34 +213,41 @@ with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility
 
     with gr.Row():
         generate_btn = gr.Button("Generate")
+        stop_btn = gr.Button("Stop")
     # Add spacing between probes
     gr.HTML("<br><br>")
 
-    with gr.Row():
-
+    # with gr.Row():
+    with gr.Column():
+        title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
+        se_output = gr.HTML(label="Semantic Uncertainty Probe")
+    # with gr.Column():
     # make a box
-        title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
-        se_output = gr.HTML(label="Semantic Uncertainty Probe")
-
+    #     title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
+    #     se_output = gr.HTML(label="Semantic Uncertainty Probe")
+
     # Add spacing between columns
-        gr.HTML("<div style='width: 20px;'></div>")
+    # gr.HTML("<div style='width: 20px;'></div>")
 
-    with gr.Column():
-        title = gr.HTML("<h2>Accuracy Probe</h2>")
-        acc_output = gr.HTML(label="Accuracy Probe")
+    # with gr.Column():
+    #     title = gr.HTML("<h2>Accuracy Probe</h2>")
+    #     acc_output = gr.HTML(label="Accuracy Probe")
 
     gr.Examples(
         examples=EXAMPLES,
        inputs=[message, system_prompt],
-        outputs=[se_output, acc_output],
+        # outputs=[se_output, acc_output],
+        outputs=[se_output],
        fn=generate,
     )
 
-    generate_btn.click(
+    generate_event = generate_btn.click(
        generate,
        inputs=[message, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[se_output, acc_output]
+        # outputs=[se_output, acc_output]
+        outputs=[se_output]
     )
+    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[generate_event])
 
 
 if __name__ == "__main__":
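
The core of this change is the sentence-level loop in generate(): tokens are produced one at a time with max_new_tokens=1, accumulated into a running chunk, and the probes are run whenever the chunk ends in sentence punctuation. Below is a minimal, self-contained sketch of that pattern; fake_token_stream and fake_probe_score are hypothetical stand-ins so the sketch runs without the model or the trained probes, and it is not the app's actual implementation.

from typing import Iterator, Tuple

SENTENCE_ENDINGS = ('.', '!', '?', ';', '."', '!"', '?"')


def fake_token_stream() -> Iterator[str]:
    # Hypothetical stand-in for decoding one token at a time with
    # model.generate(..., max_new_tokens=1).
    yield from ["Paris", " is", " the", " capital", " of", " France", ".",
                " It", " is", " known", " for", " the", " Eiffel", " Tower", "."]


def fake_probe_score(sentence: str) -> float:
    # Stand-in for se_probe.predict_proba(features)[0][1] * 2 - 1, which maps a
    # class probability in [0, 1] onto a score in [-1, 1].
    return max(-1.0, min(1.0, 1.0 - len(sentence) / 40))


def sentence_level_scores(tokens: Iterator[str]) -> Iterator[Tuple[str, float]]:
    sentence = ""
    for token in tokens:
        sentence += token
        # Score a chunk only once it looks like a finished sentence, as the
        # app does before calling highlight_text.
        if sentence.endswith(SENTENCE_ENDINGS):
            yield sentence.strip(), fake_probe_score(sentence)
            sentence = ""
    if sentence:  # flush whatever is left when generation stops
        yield sentence.strip(), fake_probe_score(sentence)


for chunk, score in sentence_level_scores(fake_token_stream()):
    print(f"{score:+.2f}  {chunk}")

Scoring at punctuation boundaries mirrors the TODO left in the diff about detecting factoids rather than relying on raw punctuation.
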
debug.ipynb
CHANGED
@@ -67,19 +67,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "
-    "
+    "se_probe = probe_data['t_bmodel']\n",
+    "se_layer_range = probe_data['sep_layer_range']\n",
+    "acc_probe = probe_data['t_amodel']\n",
+    "acc_layer_range = probe_data['ap_layer_range']"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "
+       "model_id": "30a1c8e576f6448bb228b4ae9a3a8a48",
        "version_major": 2,
        "version_minor": 0
       },
@@ -89,6 +91,13 @@
      },
      "metadata": {},
      "output_type": "display_data"
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Some parameters are on the meta device device because they were offloaded to the disk.\n"
+      ]
     }
    ],
    "source": [
@@ -96,14 +105,145 @@
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n",
    "\n",
    "model_id = \"meta-llama/Llama-2-7b-chat-hf\"\n",
-   "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"
+   "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"auto\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "tokenizer.use_default_system_prompt = False"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Љ ( \"ass\n",
+      "ЪЏ\n",
+      "հ MO-OC\n",
+      "tensor(30488, device='mps:0') Љ 1.0 -0.014414779243550946\n",
+      "tensor(313, device='mps:0') ( -0.9998164331881116 0.9597905489862286\n",
+      "tensor(376, device='mps:0') \" 0.9999998197256226 -0.9792630307582237\n",
+      "tensor(465, device='mps:0') ass -0.9999994897301452 0.9680999957882863\n",
+      "tensor(13, device='mps:0') \n",
+      " -0.99999964561314 0.9983907264450047\n",
+      "tensor(31147, device='mps:0') Ъ 1.0 -0.9999976710226259\n",
+      "tensor(30282, device='mps:0') Џ 1.0 0.9999912572082477\n",
+      "tensor(13, device='mps:0') \n",
+      " 0.9999999999869607 0.9999964462206883\n",
+      "tensor(31488, device='mps:0') հ 1.0 -1.0\n",
+      "tensor(341, device='mps:0') M 0.9045896738793786 0.5590883316684834\n",
+      "tensor(29949, device='mps:0') O -0.9999999803476437 -0.5270551643185932\n",
+      "tensor(29899, device='mps:0') - 0.9992488974195408 0.9987826119127319\n",
+      "tensor(29949, device='mps:0') O -0.9713693636571169 0.9993573968241007\n",
+      "tensor(29907, device='mps:0') C -0.9999999701427968 0.9904799691607524\n",
+      " <span style=\"background-color: #FF0000; color: black\">Љ</span> <span style=\"background-color: #00FF00; color: black\">(</span> <span style=\"background-color: #FF0000; color: black\">\"</span> <span style=\"background-color: #00FF00; color: black\">ass</span> <span style=\"background-color: #00FF00; color: black\">\n",
+      "</span> <span style=\"background-color: #FF0000; color: black\">Ъ</span> <span style=\"background-color: #FF0000; color: black\">Џ</span> <span style=\"background-color: #FF0000; color: black\">\n",
+      "</span> <span style=\"background-color: #FF0000; color: black\">հ</span> <span style=\"background-color: #FF1818; color: black\">M</span> <span style=\"background-color: #00FF00; color: black\">O</span> <span style=\"background-color: #FF0000; color: black\">-</span> <span style=\"background-color: #07FF07; color: black\">O</span> <span style=\"background-color: #00FF00; color: black\">C</span>\n"
+     ]
+    }
+   ],
+   "source": [
+    "from typing import Tuple\n",
+    "\n",
+    "MAX_INPUT_TOKEN_LENGTH = 512\n",
+    "\n",
+    "\n",
+    "def generate(\n",
+    "    message: str,\n",
+    "    system_prompt: str,\n",
+    "    max_new_tokens: int = 10,\n",
+    "    temperature: float = 0.6,\n",
+    "    top_p: float = 0.9,\n",
+    "    top_k: int = 50,\n",
+    "    repetition_penalty: float = 1.2,\n",
+    ") -> Tuple[str, str]:\n",
+    "    conversation = []\n",
+    "    if system_prompt:\n",
+    "        conversation.append({\"role\": \"system\", \"content\": system_prompt})\n",
+    "    conversation.append({\"role\": \"user\", \"content\": message})\n",
+    "\n",
+    "    input_ids = tokenizer.apply_chat_template(conversation, return_tensors=\"pt\")\n",
+    "    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:\n",
+    "        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]\n",
+    "    input_ids = input_ids.to(model.device)\n",
+    "\n",
+    "    #### Generate without threading\n",
+    "    generation_kwargs = dict(\n",
+    "        input_ids=input_ids,\n",
+    "        max_new_tokens=max_new_tokens,\n",
+    "        do_sample=True,\n",
+    "        top_p=top_p,\n",
+    "        top_k=top_k,\n",
+    "        temperature=temperature,\n",
+    "        repetition_penalty=repetition_penalty,\n",
+    "        output_hidden_states=True,\n",
+    "        return_dict_in_generate=True,\n",
+    "        attention_mask=torch.ones_like(input_ids),\n",
+    "    )\n",
+    "    with torch.no_grad():\n",
+    "        outputs = model.generate(**generation_kwargs)\n",
+    "    generated_tokens = outputs.sequences[0, input_ids.shape[1]:]\n",
+    "    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n",
+    "    print(generated_text)\n",
+    "    # hidden states\n",
+    "    hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)\n",
+    "\n",
+    "    se_highlighted_text = \"\"\n",
+    "    acc_highlighted_text = \"\"\n",
+    "\n",
+    "    # skip the first hidden state as it is the prompt\n",
+    "    for i in range(1, len(hidden)):\n",
+    "\n",
+    "        # Semantic Uncertainty Probe\n",
+    "        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)\n",
+    "        se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)\n",
+    "        se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1\n",
+    "        \n",
+    "        # Accuracy Probe\n",
+    "        acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)\n",
+    "        acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1\n",
+    "        \n",
+    "        output_id = outputs.sequences[0, input_ids.shape[1]+i]\n",
+    "        output_word = tokenizer.decode(output_id)\n",
+    "        print(output_id, output_word, se_probe_pred, acc_probe_pred) \n",
+    "\n",
+    "        se_new_highlighted_text = highlight_text(output_word, se_probe_pred)\n",
+    "        acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)\n",
+    "        se_highlighted_text += f\" {se_new_highlighted_text}\"\n",
+    "        acc_highlighted_text += f\" {acc_new_highlighted_text}\"\n",
+    "    \n",
+    "    return se_highlighted_text, acc_highlighted_text\n",
+    "\n",
+    "\n",
+    "def highlight_text(text: str, uncertainty_score: float) -> str:\n",
+    "    if uncertainty_score > 0:\n",
+    "        html_color = \"#%02X%02X%02X\" % (\n",
+    "            255,\n",
+    "            int(255 * (1 - uncertainty_score)),\n",
+    "            int(255 * (1 - uncertainty_score)),\n",
+    "        )\n",
+    "    else:\n",
+    "        html_color = \"#%02X%02X%02X\" % (\n",
+    "            int(255 * (1 + uncertainty_score)),\n",
+    "            255,\n",
+    "            int(255 * (1 + uncertainty_score)),\n",
+    "        )\n",
+    "    return '<span style=\"background-color: {}; color: black\">{}</span>'.format(\n",
+    "        html_color, text\n",
+    "    )\n",
+    "\n",
+    "message = \"What is the capital of France?\"\n",
+    "system_prompt = \"\"\n",
+    "se_highlighted_text, acc_highlighted_text = generate(message, system_prompt)\n",
+    "print(se_highlighted_text)\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -113,14 +253,23 @@
     "tensor([[ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492,\n",
     " 526, 263, 8444, 20255, 29889, 13, 29966, 829, 14816, 29903,\n",
     " 6778, 13, 13, 5816, 338, 278, 7483, 310, 3444, 29973,\n",
-    " 518, 29914, 25580, 29962]]) torch.Size([1, 34])\n"
+    " 518, 29914, 25580, 29962]]) torch.Size([1, 34])\n",
+    "\n",
+    " \n"
     ]
    },
    {
-    "
-    "
-    "
-
+    "ename": "KeyboardInterrupt",
+    "evalue": "",
+    "output_type": "error",
+    "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[13], line 37\u001b[0m\n\u001b[1;32m 35\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 36\u001b[0m highlighted_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m streamer:\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(output)\n\u001b[1;32m 39\u001b[0m generated_text \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m output\n",
"File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/site-packages/transformers/generation/streamers.py:223\u001b[0m, in \u001b[0;36mTextIteratorStreamer.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 223\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtext_queue\u001b[38;5;241m.\u001b[39mget(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtimeout)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstop_signal:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m()\n",
"File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/queue.py:180\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m remaining \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m:\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[0;32m--> 180\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnot_empty\u001b[38;5;241m.\u001b[39mwait(remaining)\n\u001b[1;32m 181\u001b[0m item \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get()\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnot_full\u001b[38;5;241m.\u001b[39mnotify()\n",
"File \u001b[0;32m~/anaconda3/envs/llm-test/lib/python3.11/threading.py:324\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 324\u001b[0m gotit \u001b[38;5;241m=\u001b[39m waiter\u001b[38;5;241m.\u001b[39macquire(\u001b[38;5;28;01mTrue\u001b[39;00m, timeout)\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 326\u001b[0m gotit \u001b[38;5;241m=\u001b[39m waiter\u001b[38;5;241m.\u001b[39macquire(\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
    ]
   }
  ],
@@ -161,17 +310,22 @@
    "\n",
    "generated_text = \"\"\n",
    "highlighted_text = \"\"\n",
+   "for output in streamer:\n",
+   "    print(output)\n",
+   "    generated_text += output\n",
    "\n",
+   "    # yield generated_text\n",
    "for new_text in streamer:\n",
    "    print(new_text)\n",
    "    generated_text += new_text\n",
+   "    print(generated_text)\n",
    "    current_input_ids = tokenizer.encode(generated_text, return_tensors=\"pt\").to(model.device)\n",
    "    print(current_input_ids, current_input_ids.shape)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(current_input_ids, output_hidden_states=True)\n",
-   "        print(outputs)\n",
    "    hidden = outputs.hidden_states \n",
-   "    print(hidden
+   "    print(len(hidden))\n",
+   "    print(hidden[-1].shape)\n",
    "    # Stack second last token embeddings from all layers \n",
    "    # if len(hidden) == 1: # FIX: runtime error for mistral-7b on bioasq\n",
    "    #     sec_last_input = hidden[0]\n",
@@ -179,9 +333,9 @@
    "    #     sec_last_input = hidden[-2]\n",
    "    # else:\n",
    "    #     sec_last_input = hidden[n_generated - 2]\n",
-   "
-   "
-   "    last_hidden_state =
+   "    sec_last_token_embedding = torch.stack([layer[:, -1, :].cpu() for layer in hidden])\n",
+   "    print(sec_last_token_embedding.shape)\n",
+   "    last_hidden_state = hidden[-1][:, -1, :].cpu().numpy()\n",
    "    print(last_hidden_state.shape) \n",
    "    # TODO potentially need to only compute uncertainty for the last token in sentence?\n"
    ]
@@ -194,8 +348,7 @@
    "source": [
    "# concat hidden states\n",
    "\n",
-   "\n",
-   "hidden_states = np.concatenate(np.array(hidden_states)[layer_range], axis=1)\n",
+   "sec_last_token_embedding = np.concatenate(sec_last_token_embedding.cpu().numpy()[layer_range], axis=1)\n",
    "# predict with probe\n",
    "pred = probe.predict(hidden_states)\n",
    "print(pred)"
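
The last notebook cells build the probe input the same way app.py does: stack one embedding per layer, slice the probe's layer range, flatten, and score with predict_proba, then rescale the probability from [0, 1] to [-1, 1]. The sketch below illustrates that recipe under assumptions: toy tensor sizes and a throwaway scikit-learn LogisticRegression stand in for the pickled probes and the real hidden states, so it is illustrative only.

import numpy as np
from sklearn.linear_model import LogisticRegression

num_layers, hidden_size = 33, 64          # toy sizes; the real model has many more hidden units
layer_range = (20, 25)                    # plays the role of probe_data['sep_layer_range']

# Pretend this is hidden[i]: one (hidden_size,) vector per layer for a single token.
token_embeddings = np.random.randn(num_layers, hidden_size)

# Train a throwaway probe on random data just so predict_proba works.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(32, (layer_range[1] - layer_range[0]) * hidden_size))
y_train = rng.integers(0, 2, size=32)
probe = LogisticRegression(max_iter=200).fit(X_train, y_train)

# Same recipe as the app and notebook: slice the layer range, flatten,
# take the class-1 probability, and rescale to [-1, 1].
concat_layers = token_embeddings[layer_range[0]:layer_range[1]].reshape(-1)
score = probe.predict_proba(concat_layers.reshape(1, -1))[0][1] * 2 - 1
print(round(score, 3))

A positive score would be rendered red (more uncertain) and a negative score green (more certain) by highlight_text in app.py.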