Refactor analysis process
- analyse.py +385 -327
- api.py +7 -2
- config.ini +2 -1
- demo.py +23 -6
- global_config.py +4 -1
- main.py +13 -0
- processors.py +7 -7
- schemes.py +1 -1
- seed_schemes/dummy_hash.py +2 -0
- stegno.py +19 -14
analyse.py
CHANGED
@@ -1,4 +1,6 @@
 import os
+from datetime import datetime
+from copy import deepcopy
 import json
 import base64
 from argparse import ArgumentParser
@@ -16,50 +18,27 @@ rng = torch.Generator(device="cpu")
 rng.manual_seed(0)


-def load_msgs(msg_lens: list[int], file: str | None = None):
-    msgs = None
-    if file is not None and os.path.isfile(file):
-        with open(file, "r") as f:
-            msgs = json.load(f)
-        if "readable" not in msgs and "random" not in msgs:
-            msgs = None
-        else:
-            return msgs
-
-    msgs = {
-        "readable": [],
-        "random": [],
-    }
-
+def load_msgs(msg_lens: list[int]):
+    msgs = []
     c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
     iterator = iter(c4_en)

     for length in tqdm(msg_lens, desc="Loading messages"):
         random_msg = torch.randint(256, (length,), generator=rng)
-        base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode(
-            "ascii"
-        )
-        msgs["random"].append(base64_msg)
+        msgs.append(["random", bytes(random_msg.tolist())])

         while True:
             readable_msg = next(iterator)["text"]
             try:
-                readable_msg[:length].encode("ascii")
+                msgs.append(["readable", readable_msg[:length].encode("ascii")])
                 break
             except Exception as e:
                 continue
-        msgs["readable"].append(readable_msg[:length])

     return msgs


-def load_prompts(n: int, prompt_size: int, file: str | None = None):
-    prompts = None
-    if file is not None and os.path.isfile(file):
-        with open(file, "r") as f:
-            prompts = json.load(f)
-        return prompts
-
+def load_prompts(tokenizer, n: int, prompt_size: int):
     prompts = []

     c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
@@ -68,14 +47,326 @@ def load_prompts(n: int, prompt_size: int, file: str | None = None):
     with tqdm(total=n, desc="Loading prompts") as pbar:
         while len(prompts) < n:
             text = next(iterator)["text"]
-            …
+            input_ids = tokenizer.encode(text, return_tensors="pt")
+            if input_ids.size(1) < prompt_size:
                 continue
-            …
+            truncated_text = tokenizer.batch_decode(input_ids[:, :prompt_size])[
+                0
+            ]
+            prompts.append(truncated_text)
             pbar.update()

     return prompts


+class AnalyseProcessor(object):
+    params_names = [
+        "msgs",
+        "bases",
+        "deltas",
+    ]
+
+    def __init__(
+        self,
+        save_file: str,
+        save_freq: int | None = None,
+        gen_model: str | None = None,
+        judge_model: str | None = None,
+        msgs: list[bytes] | None = None,
+        bases: list[int] | None = None,
+        deltas: list[float] | None = None,
+        prompts: list[str] | None = None,
+        repeat: int = 1,
+        gen_params: dict | None = None,
+        batch_size: int = 1,
+    ):
+        self.save_file = save_file
+        self.save_freq = save_freq
+        self.data = {
+            "params": {
+                "gen_model": gen_model,
+                "judge_model": judge_model,
+                "ptrs": {
+                    "msgs": 0,
+                    "bases": 0,
+                    "deltas": 0,
+                },
+                "values": {
+                    "msgs": msgs,
+                    "bases": bases,
+                    "deltas": deltas,
+                },
+                "prompts": prompts,
+                "batch_size": batch_size,
+                "repeat": repeat,
+                "gen": gen_params,
+            },
+            "results": [],
+        }
+        self.__pbar = None
+        self.last_saved = None
+        self.skip_first = False
+
+    def run(self, depth=0):
+        if self.__pbar is None:
+            total = 1
+            for v in self.data["params"]["values"].keys():
+                if v is None:
+                    raise RuntimeError(f"values must not be None when running")
+
+            initial = 0
+            for param_name in self.params_names[::-1]:
+                initial += total * self.data["params"]["ptrs"][param_name]
+                total *= len(self.data["params"]["values"][param_name])
+
+            if self.skip_first:
+                initial += 1
+
+            self.__pbar = tqdm(
+                desc="Generating",
+                total=total,
+                initial=initial,
+            )
+
+        if depth < len(self.params_names):
+            param_name = self.params_names[depth]
+
+            while self.data["params"]["ptrs"][param_name] < len(
+                self.data["params"]["values"][param_name]
+            ):
+                self.run(depth + 1)
+                self.data["params"]["ptrs"][param_name] = (
+                    self.data["params"]["ptrs"][param_name] + 1
+                )
+
+            self.data["params"]["ptrs"][param_name] = 0
+            if depth == 0:
+                self.save_data(self.save_file)
+        else:
+            if self.skip_first:
+                self.skip_first = False
+                return
+            prompts = self.data["params"]["prompts"]
+
+            msg_ptr = self.data["params"]["ptrs"]["msgs"]
+            msg_type, msg = self.data["params"]["values"]["msgs"][msg_ptr]
+
+            base_ptr = self.data["params"]["ptrs"]["bases"]
+            base = self.data["params"]["values"]["bases"][base_ptr]
+
+            delta_ptr = self.data["params"]["ptrs"]["deltas"]
+            delta = self.data["params"]["values"]["deltas"][delta_ptr]
+
+            model, tokenizer = ModelFactory.load_model(
+                self.data["params"]["gen_model"]
+            )
+            l = 0
+            while l < len(prompts):
+                start = datetime.now()
+                r = l + self.data["params"]["batch_size"]
+                r = min(r, len(prompts))
+
+                texts, msgs_rates, _ = generate(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=prompts[l:r],
+                    msg=msg,
+                    msg_base=base,
+                    delta=delta,
+                    **self.data["params"]["gen"],
+                )
+                end = datetime.now()
+                for i in range(len(texts)):
+                    prompt_ptr = l + i
+                    text = texts[i]
+                    msg_rate = msgs_rates[i]
+                    self.data["results"].append(
+                        {
+                            "ptrs": {
+                                "prompts": prompt_ptr,
+                                "msgs": msg_ptr,
+                                "bases": base_ptr,
+                                "deltas": delta_ptr,
+                            },
+                            "perplexity": ModelFactory.compute_perplexity(
+                                self.data["params"]["judge_model"], text
+                            ),
+                            "text": text,
+                            "msg_rate": msg_rate,
+                            "run_time (ms)": (end - start).microseconds
+                            / len(texts),
+                        }
+                    )
+                l += self.data["params"]["batch_size"]
+
+            postfix = {
+                "base": base,
+                "msg_len": len(msg),
+                "delta": delta,
+            }
+            self.__pbar.refresh()
+            if self.save_freq and (self.__pbar.n + 1) % self.save_freq == 0:
+                self.save_data(self.save_file)
+
+            if self.last_saved is not None:
+                seconds = (datetime.now() - self.last_saved).seconds
+                minutes = seconds // 60
+                hours = minutes // 60
+                minutes %= 60
+                seconds %= 60
+                postfix["last_saved"] = f"{hours}:{minutes}:{seconds} ago"
+
+            self.__pbar.set_postfix(postfix)
+            self.__pbar.update()
+
+    def __get_mean(self, ptrs: dict, value_name: str):
+        s = 0
+        cnt = 0
+        for r in self.data["results"]:
+            msg_type, msg = self.data["params"]["values"]["msgs"][
+                r["ptrs"]["msgs"]
+            ]
+            valid = True
+            for k in ptrs:
+                if (
+                    (k in r["ptrs"] and r["ptrs"][k] != ptrs[k])
+                    or (k == "msg_len" and len(msg) != ptrs[k])
+                    or (k == "msg_type" and msg_type != ptrs[k])
+                ):
+                    valid = False
+                    break
+
+            if valid:
+                s += r[value_name]
+                cnt += 1
+        if cnt == 0:
+            cnt = 1
+        return s / cnt
+
+    def plot(self, figs_dir: str):
+        os.makedirs(figs_dir, exist_ok=True)
+        msg_set = set()
+        for msg_type, msg in self.data["params"]["values"]["msgs"]:
+            msg_set.add((msg_type, len(msg)))
+        msg_set = sorted(msg_set)
+
+        # Delta effect
+        os.makedirs(os.path.join(figs_dir, "delta_effect"), exist_ok=True)
+        for value_name in ["perplexity", "msg_rate"]:
+            fig = plt.figure(dpi=300)
+            for base_ptr, base in enumerate(
+                self.data["params"]["values"]["bases"]
+            ):
+                for msg_type, msg_len in msg_set:
+                    x = []
+                    y = []
+                    for delta_ptr, delta in enumerate(
+                        self.data["params"]["values"]["deltas"]
+                    ):
+                        x.append(delta)
+                        y.append(
+                            self.__get_mean(
+                                ptrs={
+                                    "bases": base_ptr,
+                                    "msg_type": msg_type,
+                                    "msg_len": msg_len,
+                                    "deltas": delta_ptr,
+                                },
+                                value_name=value_name,
+                            )
+                        )
+                    plt.plot(
+                        x,
+                        y,
+                        label=f"B={base}, msg_type={msg_type}, msg_len={msg_len}",
+                    )
+            plt.ylim(ymin=0)
+            plt.legend()
+            plt.savefig(
+                os.path.join(figs_dir, "delta_effect", f"{value_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+
+        # Message length effect
+        os.makedirs(os.path.join(figs_dir, "msg_len_effect"), exist_ok=True)
+        for value_name in ["perplexity", "msg_rate"]:
+            fig = plt.figure(dpi=300)
+            for base_ptr, base in enumerate(
+                self.data["params"]["values"]["bases"]
+            ):
+                for delta_ptr, delta in enumerate(
+                    self.data["params"]["values"]["deltas"]
+                ):
+                    x = {}
+                    y = {}
+                    for msg_type, msg_len in msg_set:
+                        if msg_type not in x:
+                            x[msg_type] = []
+                        if msg_type not in y:
+                            y[msg_type] = []
+                        x[msg_type].append(msg_len)
+                        y[msg_type].append(
+                            self.__get_mean(
+                                ptrs={
+                                    "bases": base_ptr,
+                                    "msg_type": msg_type,
+                                    "msg_len": msg_len,
+                                    "deltas": delta_ptr,
+                                },
+                                value_name=value_name,
+                            )
+                        )
+                    for msg_type in x:
+                        plt.plot(
+                            x[msg_type],
+                            y[msg_type],
+                            label=f"B={base}, msg_type={msg_type}, delta={delta}",
+                        )
+            plt.ylim(ymin=0)
+            plt.legend()
+            plt.savefig(
+                os.path.join(figs_dir, "msg_len_effect", f"{value_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+        print(f"Saved figures to {figs_dir}")
+
+    def save_data(self, file_name: str):
+        if file_name is None:
+            return
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
+        data = deepcopy(self.data)
+        for i in range(len(data["params"]["values"]["msgs"])):
+            msg_type, msg = data["params"]["values"]["msgs"][i]
+            if msg_type == "random":
+                str_msg = base64.b64encode(msg).decode("ascii")
+            else:
+                str_msg = msg.decode("ascii")
+            data["params"]["values"]["msgs"][i] = [msg_type, str_msg]
+
+        with open(file_name, "w") as f:
+            json.dump(data, f, indent=2)
+        if self.__pbar is None:
+            print(f"Saved AnalyseProcessor data to {file_name}")
+        else:
+            self.last_saved = datetime.now()
+
+    def load_data(self, file_name: str):
+        with open(file_name, "r") as f:
+            self.data = json.load(f)
+        for i in range(len(self.data["params"]["values"]["msgs"])):
+            msg_type, str_msg = self.data["params"]["values"]["msgs"][i]
+            if msg_type == "random":
+                msg = base64.b64decode(str_msg)
+            else:
+                msg = str_msg.encode("ascii")
+            self.data["params"]["values"]["msgs"][i] = [msg_type, msg]
+
+        self.skip_first = len(self.data["results"]) > 0
+        self.__pbar = None
+
+
 def create_args():
     parser = ArgumentParser()

@@ -105,7 +396,7 @@ def create_args():
     parser.add_argument(
         "--num-prompts",
         type=int,
-        default=…,
+        default=10,
         help="Number of prompts",
     )
     parser.add_argument(
@@ -128,6 +419,12 @@ def create_args():
         default="gpt2",
         help="Model used to generate",
     )
+    parser.add_argument(
+        "--judge-model",
+        type=str,
+        default="gpt2",
+        help="Model used to compute score perplexity of generated text",
+    )
     parser.add_argument(
         "--deltas",
         nargs=3,
@@ -140,12 +437,23 @@ def create_args():
         type=int,
         help="Bases used in base encoding",
     )
+
+    # Generate parameters
     parser.add_argument(
-        "--judge-model",
-        type=str,
-        default="gpt2",
-        help="Model used to compute score perplexity of generated text",
+        "--do-sample",
+        action="store_true",
+        help="Whether to use sample or greedy search",
     )
+    parser.add_argument(
+        "--num-beams", type=int, default=1, help="How many beams to use"
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1,
+        help="Batch size used for generating",
+    )
+
     # Results
     parser.add_argument(
         "--repeat",
@@ -154,19 +462,19 @@ def create_args():
        help="How many times to repeat for each set of parameters, prompts and messages",
     )
     parser.add_argument(
-        "--…",
+        "--load-file",
         type=str,
         default=None,
-        help="Where to load …",
+        help="Where to load data for AnalyseProcessor",
     )
     parser.add_argument(
-        "--…",
+        "--save-file",
         type=str,
         default=None,
-        help="Where to save …",
+        help="Where to save data for AnalyseProcessor",
     )
     parser.add_argument(
-        "--…",
+        "--save-freq", type=int, default=100, help="Save frequency"
     )
     parser.add_argument(
         "--figs-dir",
@@ -178,268 +486,10 @@ def create_args():
     return parser.parse_args()


-def get_results(args, prompts, msgs):
-    model, tokenizer = ModelFactory.load_model(args.gen_model)
-    results = []
-    total_gen = (
-        len(prompts)
-        * int(args.deltas[2])
-        * len(args.bases)
-        * args.repeat
-        * sum([len(msgs[k]) for k in msgs])
-    )
-
-    with tqdm(total=total_gen, desc="Generating") as pbar:
-        for k in msgs:
-            msg_type = k
-            for msg in msgs[k]:
-                msg_bytes = (
-                    msg.encode("ascii")
-                    if k == "readable"
-                    else base64.b64decode(msg)
-                )
-                for base in args.bases:
-                    for delta in np.linspace(
-                        args.deltas[0], args.deltas[1], int(args.deltas[2])
-                    ):
-                        for prompt in prompts:
-                            for _ in range(args.repeat):
-                                text, msg_rate, tokens_info = generate(
-                                    tokenizer=tokenizer,
-                                    model=model,
-                                    prompt=prompt,
-                                    msg=msg_bytes,
-                                    start_pos_p=[0],
-                                    delta=delta,
-                                    msg_base=base,
-                                    seed_scheme="sha_left_hash",
-                                    window_length=1,
-                                    private_key=0,
-                                    min_new_tokens_ratio=1,
-                                    max_new_tokens_ratio=2,
-                                    num_beams=4,
-                                    repetition_penalty=1.5,
-                                    prompt_size=args.prompt_size,
-                                )
-                                results.append(
-                                    {
-                                        "msg_type": msg_type,
-                                        "delta": delta.item(),
-                                        "base": base,
-                                        "perplexity": ModelFactory.compute_perplexity(
-                                            args.judge_model, text
-                                        ),
-                                        "msg_rate": msg_rate,
-                                        "msg_len": len(msg_bytes),
-                                    }
-                                )
-                                pbar.set_postfix(
-                                    {
-                                        "perplexity": results[-1]["perplexity"],
-                                        "msg_rate": results[-1]["msg_rate"],
-                                        "msg_len": len(msg_bytes),
-                                        "delta": delta.item(),
-                                        "base": base,
-                                    }
-                                )
-                                if (
-                                    len(results) + 1
-                                ) % args.results_save_freq == 0:
-                                    if args.results_save_file:
-                                        os.makedirs(
-                                            os.path.dirname(
-                                                args.results_save_file
-                                            ),
-                                            exist_ok=True,
-                                        )
-                                        with open(
-                                            args.results_save_file, "w"
-                                        ) as f:
-                                            json.dump(results, f)
-                                        print(
-                                            f"Saved results to {args.results_save_file}"
-                                        )
-
-                                pbar.update()
-    return results
-
-
-def process_results(results, save_dir):
-    data = {
-        "perplexities": {
-            "random": {},
-            "readable": {},
-        },
-        "msg_rates": {
-            "random": {},
-            "readable": {},
-        },
-    }
-    for r in results:
-        msg_type = r["msg_type"]
-        base = r["base"]
-        delta = r["delta"]
-        msg_rate = r["msg_rate"]
-        msg_len = r["msg_len"]
-        perplexity = r["perplexity"]
-
-        if (base, delta, msg_len) not in data["msg_rates"][msg_type]:
-            data["msg_rates"][msg_type][(base, delta, msg_len)] = []
-        data["msg_rates"][msg_type][(base, delta, msg_len)].append(msg_rate)
-
-        if (base, delta, msg_len) not in data["perplexities"][msg_type]:
-            data["perplexities"][msg_type][(base, delta, msg_len)] = []
-        data["perplexities"][msg_type][(base, delta, msg_len)].append(
-            perplexity
-        )
-
-    bases = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    deltas = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    msgs_lens = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    values = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    base_set = set()
-    delta_set = set()
-    msgs_lens_set = set()
-    for metric in data:
-        for msg_type in data[metric]:
-            for k in data[metric][msg_type]:
-                s = sum(data[metric][msg_type][k])
-                cnt = len(data[metric][msg_type][k])
-                data[metric][msg_type][k] = s / cnt
-
-                bases[metric][msg_type].append(k[0])
-                deltas[metric][msg_type].append(k[1])
-                msgs_lens[metric][msg_type].append(k[2])
-                values[metric][msg_type].append(s / cnt)
-                base_set.add(k[0])
-                delta_set.add(k[1])
-                msgs_lens_set.add(k[2])
-
-    for metric in data:
-        for msg_type in data[metric]:
-            bases[metric][msg_type] = np.array(
-                bases[metric][msg_type], dtype=np.int64
-            )
-            deltas[metric][msg_type] = np.array(
-                deltas[metric][msg_type], dtype=np.int64
-            )
-            msgs_lens[metric][msg_type] = np.array(
-                msgs_lens[metric][msg_type], dtype=np.int64
-            )
-
-            values[metric][msg_type] = np.array(
-                values[metric][msg_type], dtype=np.float64
-            )
-
-    os.makedirs(save_dir, exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            s = lambda x: 3.0 + x * (30 if metric == "msg_rates" else 10)
-            plt.scatter(
-                bases[metric][msg_type],
-                deltas[metric][msg_type],
-                s(values[metric][msg_type]),
-            )
-            plt.savefig(
-                os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-    os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            for base_value in base_set:
-                deltas_avg = np.array(list(sorted(delta_set)))
-                values_avg = np.zeros_like(deltas_avg, dtype=np.float64)
-                for i in range(len(deltas_avg)):
-                    mask = (deltas[metric][msg_type] == deltas_avg[i]) & (
-                        bases[metric][msg_type] == base_value
-                    )
-                    values_avg[i] = np.mean(values[metric][msg_type][mask])
-                plt.plot(deltas_avg, values_avg, label=f"Base {base_value}")
-
-            plt.legend()
-            plt.savefig(
-                os.path.join(
-                    save_dir,
-                    f"delta_effect/{metric}_{msg_type}.pdf",
-                ),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-    os.makedirs(os.path.join(save_dir, "msg_len_effect"), exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            for base_value in base_set:
-                msgs_lens_avg = np.array(sorted(list(msgs_lens_set)))
-                values_avg = np.zeros_like(msgs_lens_avg, dtype=np.float64)
-                for i in range(len(msgs_lens_avg)):
-                    mask = (msgs_lens[metric][msg_type] == msgs_lens_avg[i]) & (
-                        bases[metric][msg_type] == base_value
-                    )
-                    values_avg[i] = np.mean(values[metric][msg_type][mask])
-
-                plt.plot(msgs_lens_avg, values_avg, label=f"Base {base_value}")
-
-            plt.legend()
-            plt.savefig(
-                os.path.join(
-                    save_dir,
-                    f"msg_len_effect/{metric}_{msg_type}.pdf",
-                ),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-
 def main(args):
-    if not args.…:
-        prompts = load_prompts(
-            args.num_prompts,
-            args.prompt_size,
-            args.prompts_file if not args.overwrite else None,
-        )
+    if not args.load_file:
+        model, tokenizer = ModelFactory.load_model(args.gen_model)
+        prompts = load_prompts(tokenizer, args.num_prompts, args.prompt_size)

         msgs_lens = []
         for i in np.linspace(
@@ -451,36 +501,44 @@ def main(args):
            for _ in range(args.msgs_per_length):
                msgs_lens.append(i)

-        msgs = load_msgs(
-            msgs_lens,
-            args.msgs_file if not args.overwrite else None,
+        msgs = load_msgs(msgs_lens)
+
+        processor = AnalyseProcessor(
+            save_file=args.save_file,
+            save_freq=args.save_freq,
+            gen_model=args.gen_model,
+            judge_model=args.judge_model,
+            msgs=msgs,
+            bases=args.bases,
+            deltas=np.linspace(
+                args.deltas[0], args.deltas[1], int(args.deltas[2])
+            ).tolist(),
+            prompts=prompts,
+            batch_size=args.batch_size,
+            gen_params=dict(
+                start_pos_p=[0],
+                seed_scheme="dummy_hash",
+                window_length=1,
+                min_new_tokens_ratio=1,
+                max_new_tokens_ratio=1,
+                do_sample=args.do_sample,
+                num_beams=args.num_beams,
+                repetition_penalty=1.0,
+            ),
         )
-
-        if args.msgs_file:
-            if not os.path.isfile(args.msgs_file) or args.overwrite:
-                os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
-                with open(args.msgs_file, "w") as f:
-                    json.dump(msgs, f)
-                print(f"Saved messages to {args.msgs_file}")
-        if args.prompts_file:
-            if not os.path.isfile(args.prompts_file) or args.overwrite:
-                os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
-                with open(args.prompts_file, "w") as f:
-                    json.dump(prompts, f)
-                print(f"Saved prompts to {args.prompts_file}")
-        results = get_results(args, prompts, msgs)
+        processor.save_data(args.save_file)
     else:
-        …
-        …
+        processor = AnalyseProcessor(
+            save_file=args.save_file,
+            save_freq=args.save_freq,
+        )
+        processor.load_data(args.load_file)

-    …
-    …
-        with open(args.results_save_file, "w") as f:
-            json.dump(results, f)
-        print(f"Saved results to {args.results_save_file}")
+    processor.run()
+    processor.plot(args.figs_dir)

-    if args.figs_dir:
-        process_results(results, args.figs_dir)
+    # if args.figs_dir:
+    #     process_results(results, args.figs_dir)


 if __name__ == "__main__":
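Reviewer notes on analyse.py: the ptrs/values bookkeeping is what makes the sweep resumable, but two details look off. The None-guard in run() iterates self.data["params"]["values"].keys(), so v is always a key string and the RuntimeError can never fire; iterating .values() would catch a missing grid. And (end - start).microseconds keeps only the sub-second component, so "run_time (ms)" under-reports any batch slower than one second; (end - start).total_seconds() * 1000 / len(texts) would be actual milliseconds. A minimal sketch of the intended resume flow, assuming the class above is importable and a checkpoint exists at the illustrative path:

    from analyse import AnalyseProcessor

    # Restore a partially finished sweep and continue it. load_data() decodes the
    # JSON-safe messages back to bytes and sets skip_first so the grid cell the
    # checkpoint already recorded is not run twice.
    processor = AnalyseProcessor(save_file="results/run.json", save_freq=100)
    processor.load_data("results/run.json")
    processor.run()         # walks the remaining msgs x bases x deltas grid
    processor.plot("figs")  # rebuilds the delta_effect/ and msg_len_effect/ figures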
api.py
CHANGED
@@ -34,7 +34,7 @@ async def encrypt_api(
 ):
     byte_msg = base64.b64decode(body.msg)
     model, tokenizer = ModelFactory.load_model(body.gen_model)
-    … = generate(
+    texts, msgs_rates, tokens_infos = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=body.prompt,
@@ -49,7 +49,11 @@ async def encrypt_api(
         num_beams=body.num_beams,
         repetition_penalty=body.repetition_penalty,
     )
-    return {…}
+    return {
+        "texts": texts,
+        "msgs_rates": msgs_rates,
+        "tokens_info": tokens_infos,
+    }


 @app.post(
@@ -114,6 +118,7 @@ async def default_config():
         "max_new_tokens_ratio": GlobalConfig.get(
             "encrypt.default", "max_new_tokens_ratio"
         ),
+        "do_sample": GlobalConfig.get("encrypt.default", "do_sample"),
         "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
         "repetition_penalty": GlobalConfig.get(
             "encrypt.default", "repetition_penalty"
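Since encrypt_api now returns batch-shaped lists, clients index into them even for a single prompt. A request sketch; the host, port, and route here are assumptions for illustration, not taken from this diff:

    import base64
    import requests

    payload = {
        "prompt": ["A quiet morning in the harbor"],  # str | list[str] after this change
        "msg": base64.b64encode(b"hi").decode("ascii"),
    }
    # Route is hypothetical; use whatever path @app.post registers in api.py.
    resp = requests.post("http://localhost:8000/encrypt", json=payload)
    body = resp.json()
    print(body["texts"][0], body["msgs_rates"][0])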
config.ini
CHANGED
@@ -34,7 +34,8 @@ window_length = int:1
private_key = int:0
min_new_tokens_ratio = float:1.0
max_new_tokens_ratio = float:2.0
-num_beams = int:…
+do_sample = bool:0
+num_beams = int:1
repetition_penalty = float:1.0

[decrypt.default]
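The new keys follow this file's type:value convention, which GlobalConfig.get() coerces on read. A read-back sketch (module path as in this repo):

    from global_config import GlobalConfig

    GlobalConfig.get("encrypt.default", "num_beams")  # -> 1, via the int branch
    GlobalConfig.get("encrypt.default", "do_sample")  # via the new bool branch; see the caveat under global_config.py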
demo.py
CHANGED
@@ -17,12 +17,14 @@ def enc_fn(
     seed_scheme: str,
     window_length: int,
     private_key: int,
+    do_sample: bool,
+    min_new_tokens_ratio: float,
     max_new_tokens_ratio: float,
     num_beams: int,
     repetition_penalty: float,
 ):
     model, tokenizer = ModelFactory.load_model(gen_model)
-    … = generate(
+    texts, msgs_rates, tokens_infos = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=prompt,
@@ -33,12 +35,14 @@ def enc_fn(
         seed_scheme=seed_scheme,
         window_length=window_length,
         private_key=private_key,
+        do_sample=do_sample,
+        min_new_tokens_ratio=min_new_tokens_ratio,
         max_new_tokens_ratio=max_new_tokens_ratio,
         num_beams=num_beams,
         repetition_penalty=repetition_penalty,
     )
     highlight_base = []
-    for token in …:
+    for token in tokens_infos[0]:
         stat = None
         if token["base_msg"] != -1:
             if token["base_msg"] == token["base_enc"]:
@@ -48,8 +52,8 @@ def enc_fn(
         highlight_base.append((repr(token["token"])[1:-1], stat))

     highlight_byte = []
-    for i, token in enumerate(…):
-        if i == 0 or …:
+    for i, token in enumerate(tokens_infos[0]):
+        if i == 0 or tokens_infos[0][i - 1]["byte_id"] != token["byte_id"]:
             stat = None
             if token["byte_msg"] != -1:
                 if token["byte_msg"] == token["byte_enc"]:
@@ -60,7 +64,12 @@ def enc_fn(
         else:
             highlight_byte[-1][0] += repr(token["token"])[1:-1]

-    return …
+    return (
+        texts[0],
+        highlight_base,
+        highlight_byte,
+        round(msgs_rates[0] * 100, 2),
+    )


 def dec_fn(
@@ -108,13 +117,21 @@ if __name__ == "__main__":
             int(GlobalConfig.get("encrypt.default", "window_length"))
         ),
         gr.Number(int(GlobalConfig.get("encrypt.default", "private_key"))),
+        gr.Number(bool(GlobalConfig.get("encrypt.default", "do_sample"))),
+        gr.Number(
+            float(
+                GlobalConfig.get("encrypt.default", "min_new_tokens_ratio")
+            )
+        ),
         gr.Number(
             float(
                 GlobalConfig.get("encrypt.default", "max_new_tokens_ratio")
             )
         ),
         gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))),
-        gr.Number(…),
+        gr.Number(
+            float(GlobalConfig.get("encrypt.default", "repetition_penalty"))
+        ),
     ],
     outputs=[
         gr.Textbox(
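Note that gr.Number renders a numeric field, so wrapping do_sample in bool() still shows a 0/1 box rather than a toggle; Gradio's boolean component is a checkbox. A sketch of the alternative (label and default are illustrative):

    import gradio as gr

    # Renders an actual on/off toggle instead of a 0/1 number input.
    do_sample_input = gr.Checkbox(value=False, label="do_sample")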
global_config.py
CHANGED
@@ -9,7 +9,7 @@ class GlobalConfig:

     @classmethod
     def get_section(cls, section_name):
-        if section_name in cls.config
+        if section_name in cls.config:
             return cls.config[section_name].keys()
         else:
             return None
@@ -27,6 +27,9 @@ class GlobalConfig:
                 value = float(value)
             elif type_name == "int":
                 value = int(value)
+            elif type_name == "bool":
+                value = bool(value)
+
             return value
         else:
             return None
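Caveat on the new bool branch: value is still the raw string at that point, and every non-empty string is truthy, so bool:0 in config.ini reads back as True. A stricter coercion, sketched as a standalone helper:

    def parse_bool(raw: str) -> bool:
        # bool("0") is True because the string is non-empty;
        # routing through int() gives "0" -> False and "1" -> True.
        return bool(int(raw))

    assert parse_bool("0") is False
    assert parse_bool("1") is True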
main.py
CHANGED
@@ -61,6 +61,17 @@ def create_args():
         default=GlobalConfig.get("encrypt.default", "num_beams"),
         help="Number of beams used in beam search",
     )
+    parser.add_argument(
+        "--do-sample",
+        action="store_true",
+        help="Whether to do sample or greedy search",
+    )
+    parser.add_argument(
+        "--min-new-tokens-ratio",
+        type=float,
+        default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
+        help="Ratio of min new tokens to minimum tokens required to hide message",
+    )
     parser.add_argument(
         "--max-new-tokens-ratio",
         type=float,
@@ -183,6 +194,8 @@ def main(args):
         window_length=args.window_length,
         salt_key=args.salt_key,
         private_key=args.private_key,
+        do_sample=args.do_sample,
+        min_new_tokens_ratio=args.min_new_tokens_ratio,
         max_new_tokens_ratio=args.max_new_tokens_ratio,
         num_beams=args.num_beams,
     )
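Both new flags flow straight through to the generate() call shown above. Illustrative invocations (the remaining required arguments are elided):

    # Greedy search, single beam (defaults; --do-sample omitted):
    #   python main.py --num-beams 1 --min-new-tokens-ratio 1.0 ...
    # Sampling instead of greedy:
    #   python main.py --do-sample ...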
processors.py
CHANGED
@@ -116,10 +116,10 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             delta: bias add to scores of token in valid list.
         """
         super().__init__(*args, **kwargs)
-        if prompt_ids.size(0) != 1:
-            raise RuntimeError(
-                "EncryptorLogitsProcessor does not support multiple prompts input."
-            )
+        # if prompt_ids.size(0) != 1:
+        #     raise RuntimeError(
+        #         "EncryptorLogitsProcessor does not support multiple prompts input."
+        #     )

         self.prompt_size = prompt_ids.size(1)
         self.start_pos = start_pos
@@ -192,7 +192,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         return base_msg, byte_msg, base_enc_msg, byte_enc_msg

     def validate(self, input_ids_batch: torch.Tensor):
-        …
+        msgs_rates = []
         tokens_infos = []
         for input_ids in input_ids_batch:
             # Initialization
@@ -214,7 +214,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             for i in range(min(len(enc_msg), len(self.raw_msg))):
                 if self.raw_msg[i] == enc_msg[i]:
                     cnt += 1
-            …
+            msgs_rates.append(cnt / len(self.raw_msg))

             base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
                 self.__map_input_ids(input_ids, base_arr, byte_arr)
@@ -234,7 +234,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             )
             tokens_infos.append(tokens)

-        return …
+        return msgs_rates, tokens_infos


 class DecryptorProcessor(BaseProcessor):
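validate() now returns one recovery rate per sequence, index-aligned with input_ids_batch, which is what lets the callers above batch prompts. A consumer sketch (variable names as in stegno.py):

    msgs_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)
    for i, (rate, tokens) in enumerate(zip(msgs_rates, tokens_infos)):
        # rate is the fraction of self.raw_msg recovered from sequence i;
        # tokens carries the per-token base/byte bookkeeping consumed by demo.py.
        print(f"sequence {i}: {rate:.2%} of the message recovered")

Note the single-prompt guard is only commented out, not deleted; replacing it with a shape assertion (or removing it) would keep the history cleaner.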
schemes.py
CHANGED
@@ -12,7 +12,7 @@ with open("resources/examples.json", "r") as f:


 class EncryptionBody(BaseModel):
-    prompt: str = Field(title="Prompt used to generate text")
+    prompt: str | list[str] = Field(title="Prompt used to generate text")
     msg: str = Field(title="Message wanted to hide")
     gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
         default=GlobalConfig.get("encrypt.default", "gen_model"),
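Widening prompt to str | list[str] means the request model validates batched prompts as well:

    from schemes import EncryptionBody

    single = EncryptionBody(prompt="one prompt", msg="aGk=")
    batched = EncryptionBody(prompt=["first", "second"], msg="aGk=")  # accepted after this change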
seed_schemes/dummy_hash.py
CHANGED
@@ -8,5 +8,7 @@ class DummyHash(SeedScheme):
         pass

     def __call__(self, input_ids: torch.Tensor):
+        if input_ids.size(0) == 0:
+            return 0
         return int(input_ids[-1].item())
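The new guard covers the window that precedes the first generated token, where the slice handed to the seed scheme is empty. A behavioral check (the no-argument constructor is assumed, as the pass body suggests):

    import torch
    from seed_schemes.dummy_hash import DummyHash

    scheme = DummyHash()
    print(scheme(torch.tensor([], dtype=torch.long)))  # 0 instead of an IndexError
    print(scheme(torch.tensor([5, 7])))                # 7, the last token id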
stegno.py
CHANGED
@@ -9,7 +9,7 @@ from processors import EncryptorLogitsProcessor, DecryptorProcessor
 def generate(
     tokenizer,
     model,
-    prompt: str,
+    prompt: str | list[str],
     msg: bytes,
     start_pos_p: list[int],
     delta: float,
@@ -20,9 +20,10 @@ def generate(
     private_key: Union[int, None] = None,
     min_new_tokens_ratio: float = 1,
     max_new_tokens_ratio: float = 2,
-    num_beams: int = …,
+    do_sample: bool = True,
+    num_beams: int = 1,
     repetition_penalty: float = 1.0,
-    prompt_size: int = -1,
+    generator: torch.Generator | None = None,
 ):
     """
     Generate the sequence containing the hidden data.
@@ -46,12 +47,13 @@ def generate(
         start_pos_p[0], start_pos_p[1] + 1, (1,)
     ).item()
     start_pos = int(start_pos) + window_length
+    tokenizer.pad_token = tokenizer.eos_token
+
-    tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
-    if prompt_size == -1:
-        prompt_size = tokenized_input.input_ids.size(1)
+    tokenized_input = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
+    prompt_size = tokenized_input.input_ids.size(1)

     logits_processor = EncryptorLogitsProcessor(
-        prompt_ids=tokenized_input.input_ids[:, :prompt_size],
+        prompt_ids=tokenized_input.input_ids,
         msg=msg,
         start_pos=start_pos,
         delta=delta,
@@ -77,29 +79,32 @@ def generate(
     max_length = min(max_length, tokenizer.model_max_length)
     min_length = min(min_length, max_length)
     output_tokens = model.generate(
-        input_ids=tokenized_input.input_ids[:, :prompt_size],
-        attention_mask=tokenized_input.attention_mask[:, :prompt_size],
+        **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
         min_length=min_length,
         max_length=max_length,
-        do_sample=…,
+        do_sample=do_sample,
         num_beams=num_beams,
         repetition_penalty=float(repetition_penalty),
         pad_token_id=tokenizer.eos_token_id,
+        generator=generator,
     )

     output_tokens = output_tokens[:, prompt_size:]
-    …
+    output_texts = tokenizer.batch_decode(
         output_tokens, skip_special_tokens=True
     )
     output_tokens_post = tokenizer(
-        …
+        output_texts,
+        return_tensors="pt",
+        add_special_tokens=False,
+        padding=True,
     ).to(model.device)
     msg_rates, tokens_infos = logits_processor.validate(
         output_tokens_post.input_ids
     )

-    return …
+    return output_texts, msg_rates, tokens_infos


 def decrypt(
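Two follow-ups worth checking in generate(). First, transformers' GenerationMixin.generate has no generator keyword as of recent releases; unrecognized kwargs are validated against the model's forward signature, so passing generator= is likely to raise, and the framework-level route to reproducible sampling is a global seed. Second, tokenizing a list of unequal-length prompts with return_tensors="pt" requires padding (the pad token is already set one line earlier). A hedged sketch of both, names as in the function above:

    import transformers

    transformers.set_seed(0)  # seeds Python, NumPy and torch RNGs for reproducible sampling

    # padding=True is needed for return_tensors="pt" on ragged prompt batches.
    tokenized_input = tokenizer(
        prompt, return_tensors="pt", truncation=True, padding=True
    ).to(model.device)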