tnk2908 committed on
Commit 0186ed1
1 Parent(s): 45dff64

Refactor analysis process

Files changed (10)
  1. analyse.py +385 -327
  2. api.py +7 -2
  3. config.ini +2 -1
  4. demo.py +23 -6
  5. global_config.py +4 -1
  6. main.py +13 -0
  7. processors.py +7 -7
  8. schemes.py +1 -1
  9. seed_schemes/dummy_hash.py +2 -0
  10. stegno.py +19 -14
analyse.py CHANGED
@@ -1,4 +1,6 @@
 import os
+from datetime import datetime
+from copy import deepcopy
 import json
 import base64
 from argparse import ArgumentParser
@@ -16,50 +18,27 @@ rng = torch.Generator(device="cpu")
 rng.manual_seed(0)


-def load_msgs(msg_lens: list[int], file: str | None = None):
-    msgs = None
-    if file is not None and os.path.isfile(file):
-        with open(file, "r") as f:
-            msgs = json.load(f)
-        if "readable" not in msgs and "random" not in msgs:
-            msgs = None
-        else:
-            return msgs
-
-    msgs = {
-        "readable": [],
-        "random": [],
-    }
-
+def load_msgs(msg_lens: list[int]):
+    msgs = []
     c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
     iterator = iter(c4_en)

     for length in tqdm(msg_lens, desc="Loading messages"):
         random_msg = torch.randint(256, (length,), generator=rng)
-        base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode(
-            "ascii"
-        )
-        msgs["random"].append(base64_msg)
+        msgs.append(["random", bytes(random_msg.tolist())])

         while True:
             readable_msg = next(iterator)["text"]
             try:
-                readable_msg[:length].encode("ascii")
+                msgs.append(["readable", readable_msg[:length].encode("ascii")])
                 break
             except Exception as e:
                 continue
-        msgs["readable"].append(readable_msg[:length])

     return msgs


-def load_prompts(n: int, prompt_size: int, file: str | None = None):
-    prompts = None
-    if file is not None and os.path.isfile(file):
-        with open(file, "r") as f:
-            prompts = json.load(f)
-        return prompts
-
+def load_prompts(tokenizer, n: int, prompt_size: int):
     prompts = []

     c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
@@ -68,14 +47,326 @@ def load_prompts(n: int, prompt_size: int, file: str | None = None):
     with tqdm(total=n, desc="Loading prompts") as pbar:
         while len(prompts) < n:
             text = next(iterator)["text"]
-            if len(text) < prompt_size:
+            input_ids = tokenizer.encode(text, return_tensors="pt")
+            if input_ids.size(1) < prompt_size:
                 continue
-            prompts.append(text)
+            truncated_text = tokenizer.batch_decode(input_ids[:, :prompt_size])[
+                0
+            ]
+            prompts.append(truncated_text)
             pbar.update()

     return prompts


+class AnalyseProcessor(object):
+    params_names = [
+        "msgs",
+        "bases",
+        "deltas",
+    ]
+
+    def __init__(
+        self,
+        save_file: str,
+        save_freq: int | None = None,
+        gen_model: str | None = None,
+        judge_model: str | None = None,
+        msgs: list[bytes] | None = None,
+        bases: list[int] | None = None,
+        deltas: list[float] | None = None,
+        prompts: list[str] | None = None,
+        repeat: int = 1,
+        gen_params: dict | None = None,
+        batch_size: int = 1,
+    ):
+        self.save_file = save_file
+        self.save_freq = save_freq
+        self.data = {
+            "params": {
+                "gen_model": gen_model,
+                "judge_model": judge_model,
+                "ptrs": {
+                    "msgs": 0,
+                    "bases": 0,
+                    "deltas": 0,
+                },
+                "values": {
+                    "msgs": msgs,
+                    "bases": bases,
+                    "deltas": deltas,
+                },
+                "prompts": prompts,
+                "batch_size": batch_size,
+                "repeat": repeat,
+                "gen": gen_params,
+            },
+            "results": [],
+        }
+        self.__pbar = None
+        self.last_saved = None
+        self.skip_first = False
+
+    def run(self, depth=0):
+        if self.__pbar is None:
+            total = 1
+            for v in self.data["params"]["values"].keys():
+                if v is None:
+                    raise RuntimeError(f"values must not be None when running")
+
+            initial = 0
+            for param_name in self.params_names[::-1]:
+                initial += total * self.data["params"]["ptrs"][param_name]
+                total *= len(self.data["params"]["values"][param_name])
+
+            if self.skip_first:
+                initial += 1
+
+            self.__pbar = tqdm(
+                desc="Generating",
+                total=total,
+                initial=initial,
+            )
+
+        if depth < len(self.params_names):
+            param_name = self.params_names[depth]
+
+            while self.data["params"]["ptrs"][param_name] < len(
+                self.data["params"]["values"][param_name]
+            ):
+                self.run(depth + 1)
+                self.data["params"]["ptrs"][param_name] = (
+                    self.data["params"]["ptrs"][param_name] + 1
+                )
+
+            self.data["params"]["ptrs"][param_name] = 0
+            if depth == 0:
+                self.save_data(self.save_file)
+        else:
+            if self.skip_first:
+                self.skip_first = False
+                return
+            prompts = self.data["params"]["prompts"]
+
+            msg_ptr = self.data["params"]["ptrs"]["msgs"]
+            msg_type, msg = self.data["params"]["values"]["msgs"][msg_ptr]
+
+            base_ptr = self.data["params"]["ptrs"]["bases"]
+            base = self.data["params"]["values"]["bases"][base_ptr]
+
+            delta_ptr = self.data["params"]["ptrs"]["deltas"]
+            delta = self.data["params"]["values"]["deltas"][delta_ptr]
+
+            model, tokenizer = ModelFactory.load_model(
+                self.data["params"]["gen_model"]
+            )
+            l = 0
+            while l < len(prompts):
+                start = datetime.now()
+                r = l + self.data["params"]["batch_size"]
+                r = min(r, len(prompts))
+
+                texts, msgs_rates, _ = generate(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=prompts[l:r],
+                    msg=msg,
+                    msg_base=base,
+                    delta=delta,
+                    **self.data["params"]["gen"],
+                )
+                end = datetime.now()
+                for i in range(len(texts)):
+                    prompt_ptr = l + i
+                    text = texts[i]
+                    msg_rate = msgs_rates[i]
+                    self.data["results"].append(
+                        {
+                            "ptrs": {
+                                "prompts": prompt_ptr,
+                                "msgs": msg_ptr,
+                                "bases": base_ptr,
+                                "deltas": delta_ptr,
+                            },
+                            "perplexity": ModelFactory.compute_perplexity(
+                                self.data["params"]["judge_model"], text
+                            ),
+                            "text": text,
+                            "msg_rate": msg_rate,
+                            "run_time (ms)": (end - start).microseconds
+                            / len(texts),
+                        }
+                    )
+                l += self.data["params"]["batch_size"]
+
+            postfix = {
+                "base": base,
+                "msg_len": len(msg),
+                "delta": delta,
+            }
+            self.__pbar.refresh()
+            if self.save_freq and (self.__pbar.n + 1) % self.save_freq == 0:
+                self.save_data(self.save_file)
+
+            if self.last_saved is not None:
+                seconds = (datetime.now() - self.last_saved).seconds
+                minutes = seconds // 60
+                hours = minutes // 60
+                minutes %= 60
+                seconds %= 60
+                postfix["last_saved"] = f"{hours}:{minutes}:{seconds} ago"
+
+            self.__pbar.set_postfix(postfix)
+            self.__pbar.update()
+
+    def __get_mean(self, ptrs: dict, value_name: str):
+        s = 0
+        cnt = 0
+        for r in self.data["results"]:
+            msg_type, msg = self.data["params"]["values"]["msgs"][
+                r["ptrs"]["msgs"]
+            ]
+            valid = True
+            for k in ptrs:
+                if (
+                    (k in r["ptrs"] and r["ptrs"][k] != ptrs[k])
+                    or (k == "msg_len" and len(msg) != ptrs[k])
+                    or (k == "msg_type" and msg_type != ptrs[k])
+                ):
+                    valid = False
+                    break
+
+            if valid:
+                s += r[value_name]
+                cnt += 1
+        if cnt == 0:
+            cnt = 1
+        return s / cnt
+
+    def plot(self, figs_dir: str):
+        os.makedirs(figs_dir, exist_ok=True)
+        msg_set = set()
+        for msg_type, msg in self.data["params"]["values"]["msgs"]:
+            msg_set.add((msg_type, len(msg)))
+        msg_set = sorted(msg_set)
+
+        # Delta effect
+        os.makedirs(os.path.join(figs_dir, "delta_effect"), exist_ok=True)
+        for value_name in ["perplexity", "msg_rate"]:
+            fig = plt.figure(dpi=300)
+            for base_ptr, base in enumerate(
+                self.data["params"]["values"]["bases"]
+            ):
+                for msg_type, msg_len in msg_set:
+                    x = []
+                    y = []
+                    for delta_ptr, delta in enumerate(
+                        self.data["params"]["values"]["deltas"]
+                    ):
+                        x.append(delta)
+                        y.append(
+                            self.__get_mean(
+                                ptrs={
+                                    "bases": base_ptr,
+                                    "msg_type": msg_type,
+                                    "msg_len": msg_len,
+                                    "deltas": delta_ptr,
+                                },
+                                value_name=value_name,
+                            )
+                        )
+                    plt.plot(
+                        x,
+                        y,
+                        label=f"B={base}, msg_type={msg_type}, msg_len={msg_len}",
+                    )
+            plt.ylim(ymin=0)
+            plt.legend()
+            plt.savefig(
+                os.path.join(figs_dir, "delta_effect", f"{value_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+
+        # Message length effect
+        os.makedirs(os.path.join(figs_dir, "msg_len_effect"), exist_ok=True)
+        for value_name in ["perplexity", "msg_rate"]:
+            fig = plt.figure(dpi=300)
+            for base_ptr, base in enumerate(
+                self.data["params"]["values"]["bases"]
+            ):
+                for delta_ptr, delta in enumerate(
+                    self.data["params"]["values"]["deltas"]
+                ):
+                    x = {}
+                    y = {}
+                    for msg_type, msg_len in msg_set:
+                        if msg_type not in x:
+                            x[msg_type] = []
+                        if msg_type not in y:
+                            y[msg_type] = []
+                        x[msg_type].append(msg_len)
+                        y[msg_type].append(
+                            self.__get_mean(
+                                ptrs={
+                                    "bases": base_ptr,
+                                    "msg_type": msg_type,
+                                    "msg_len": msg_len,
+                                    "deltas": delta_ptr,
+                                },
+                                value_name=value_name,
+                            )
+                        )
+                    for msg_type in x:
+                        plt.plot(
+                            x[msg_type],
+                            y[msg_type],
+                            label=f"B={base}, msg_type={msg_type}, delta={delta}",
+                        )
+            plt.ylim(ymin=0)
+            plt.legend()
+            plt.savefig(
+                os.path.join(figs_dir, "msg_len_effect", f"{value_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+        print(f"Saved figures to {figs_dir}")
+
+    def save_data(self, file_name: str):
+        if file_name is None:
+            return
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
+        data = deepcopy(self.data)
+        for i in range(len(data["params"]["values"]["msgs"])):
+            msg_type, msg = data["params"]["values"]["msgs"][i]
+            if msg_type == "random":
+                str_msg = base64.b64encode(msg).decode("ascii")
+            else:
+                str_msg = msg.decode("ascii")
+            data["params"]["values"]["msgs"][i] = [msg_type, str_msg]
+
+        with open(file_name, "w") as f:
+            json.dump(data, f, indent=2)
+        if self.__pbar is None:
+            print(f"Saved AnalyseProcessor data to {file_name}")
+        else:
+            self.last_saved = datetime.now()
+
+    def load_data(self, file_name: str):
+        with open(file_name, "r") as f:
+            self.data = json.load(f)
+        for i in range(len(self.data["params"]["values"]["msgs"])):
+            msg_type, str_msg = self.data["params"]["values"]["msgs"][i]
+            if msg_type == "random":
+                msg = base64.b64decode(str_msg)
+            else:
+                msg = str_msg.encode("ascii")
+            self.data["params"]["values"]["msgs"][i] = [msg_type, msg]
+
+        self.skip_first = len(self.data["results"]) > 0
+        self.__pbar = None
+
+
 def create_args():
     parser = ArgumentParser()

@@ -105,7 +396,7 @@ def create_args():
     parser.add_argument(
         "--num-prompts",
         type=int,
-        default=500,
+        default=10,
         help="Number of prompts",
     )
     parser.add_argument(
@@ -128,6 +419,12 @@ def create_args():
         default="gpt2",
         help="Model used to generate",
     )
+    parser.add_argument(
+        "--judge-model",
+        type=str,
+        default="gpt2",
+        help="Model used to compute score perplexity of generated text",
+    )
     parser.add_argument(
         "--deltas",
         nargs=3,
@@ -140,12 +437,23 @@ def create_args():
         type=int,
         help="Bases used in base encoding",
     )
+
+    # Generate parameters
     parser.add_argument(
-        "--judge-model",
-        type=str,
-        default="gpt2",
-        help="Model used to compute score perplexity of generated text",
+        "--do-sample",
+        action="store_true",
+        help="Whether to use sample or greedy search",
     )
+    parser.add_argument(
+        "--num-beams", type=int, default=1, help="How many beams to use"
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1,
+        help="Batch size used for generating",
+    )
+
     # Results
     parser.add_argument(
         "--repeat",
@@ -154,19 +462,19 @@ def create_args():
         help="How many times to repeat for each set of parameters, prompts and messages",
     )
     parser.add_argument(
-        "--results-load-file",
+        "--load-file",
        type=str,
        default=None,
-        help="Where to load results",
+        help="Where to load data for AnalyseProcessor",
     )
     parser.add_argument(
-        "--results-save-file",
+        "--save-file",
        type=str,
        default=None,
-        help="Where to save results",
+        help="Where to save data for AnalyseProcessor",
     )
     parser.add_argument(
-        "--results-save-freq", type=int, default=100, help="Save frequency"
+        "--save-freq", type=int, default=100, help="Save frequency"
     )
     parser.add_argument(
         "--figs-dir",
@@ -178,268 +486,10 @@ def create_args():
     return parser.parse_args()


-def get_results(args, prompts, msgs):
-    model, tokenizer = ModelFactory.load_model(args.gen_model)
-    results = []
-    total_gen = (
-        len(prompts)
-        * int(args.deltas[2])
-        * len(args.bases)
-        * args.repeat
-        * sum([len(msgs[k]) for k in msgs])
-    )
-
-    with tqdm(total=total_gen, desc="Generating") as pbar:
-        for k in msgs:
-            msg_type = k
-            for msg in msgs[k]:
-                msg_bytes = (
-                    msg.encode("ascii")
-                    if k == "readable"
-                    else base64.b64decode(msg)
-                )
-                for base in args.bases:
-                    for delta in np.linspace(
-                        args.deltas[0], args.deltas[1], int(args.deltas[2])
-                    ):
-                        for prompt in prompts:
-                            for _ in range(args.repeat):
-                                text, msg_rate, tokens_info = generate(
-                                    tokenizer=tokenizer,
-                                    model=model,
-                                    prompt=prompt,
-                                    msg=msg_bytes,
-                                    start_pos_p=[0],
-                                    delta=delta,
-                                    msg_base=base,
-                                    seed_scheme="sha_left_hash",
-                                    window_length=1,
-                                    private_key=0,
-                                    min_new_tokens_ratio=1,
-                                    max_new_tokens_ratio=2,
-                                    num_beams=4,
-                                    repetition_penalty=1.5,
-                                    prompt_size=args.prompt_size,
-                                )
-                                results.append(
-                                    {
-                                        "msg_type": msg_type,
-                                        "delta": delta.item(),
-                                        "base": base,
-                                        "perplexity": ModelFactory.compute_perplexity(
-                                            args.judge_model, text
-                                        ),
-                                        "msg_rate": msg_rate,
-                                        "msg_len": len(msg_bytes),
-                                    }
-                                )
-                                pbar.set_postfix(
-                                    {
-                                        "perplexity": results[-1]["perplexity"],
-                                        "msg_rate": results[-1]["msg_rate"],
-                                        "msg_len": len(msg_bytes),
-                                        "delta": delta.item(),
-                                        "base": base,
-                                    }
-                                )
-                                if (
-                                    len(results) + 1
-                                ) % args.results_save_freq == 0:
-                                    if args.results_save_file:
-                                        os.makedirs(
-                                            os.path.dirname(
-                                                args.results_save_file
-                                            ),
-                                            exist_ok=True,
-                                        )
-                                        with open(
-                                            args.results_save_file, "w"
-                                        ) as f:
-                                            json.dump(results, f)
-                                        print(
-                                            f"Saved results to {args.results_save_file}"
-                                        )
-
-                                pbar.update()
-    return results
-
-
-def process_results(results, save_dir):
-    data = {
-        "perplexities": {
-            "random": {},
-            "readable": {},
-        },
-        "msg_rates": {
-            "random": {},
-            "readable": {},
-        },
-    }
-    for r in results:
-        msg_type = r["msg_type"]
-        base = r["base"]
-        delta = r["delta"]
-        msg_rate = r["msg_rate"]
-        msg_len = r["msg_len"]
-        perplexity = r["perplexity"]
-
-        if (base, delta, msg_len) not in data["msg_rates"][msg_type]:
-            data["msg_rates"][msg_type][(base, delta, msg_len)] = []
-        data["msg_rates"][msg_type][(base, delta, msg_len)].append(msg_rate)
-
-        if (base, delta, msg_len) not in data["perplexities"][msg_type]:
-            data["perplexities"][msg_type][(base, delta, msg_len)] = []
-        data["perplexities"][msg_type][(base, delta, msg_len)].append(
-            perplexity
-        )
-
-    bases = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    deltas = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    msgs_lens = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    values = {
-        "perplexities": {
-            "random": [],
-            "readable": [],
-        },
-        "msg_rates": {
-            "random": [],
-            "readable": [],
-        },
-    }
-    base_set = set()
-    delta_set = set()
-    msgs_lens_set = set()
-    for metric in data:
-        for msg_type in data[metric]:
-            for k in data[metric][msg_type]:
-                s = sum(data[metric][msg_type][k])
-                cnt = len(data[metric][msg_type][k])
-                data[metric][msg_type][k] = s / cnt
-
-                bases[metric][msg_type].append(k[0])
-                deltas[metric][msg_type].append(k[1])
-                msgs_lens[metric][msg_type].append(k[2])
-                values[metric][msg_type].append(s / cnt)
-                base_set.add(k[0])
-                delta_set.add(k[1])
-                msgs_lens_set.add(k[2])
-
-    for metric in data:
-        for msg_type in data[metric]:
-            bases[metric][msg_type] = np.array(
-                bases[metric][msg_type], dtype=np.int64
-            )
-            deltas[metric][msg_type] = np.array(
-                deltas[metric][msg_type], dtype=np.int64
-            )
-            msgs_lens[metric][msg_type] = np.array(
-                msgs_lens[metric][msg_type], dtype=np.int64
-            )
-
-            values[metric][msg_type] = np.array(
-                values[metric][msg_type], dtype=np.float64
-            )
-
-    os.makedirs(save_dir, exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            s = lambda x: 3.0 + x * (30 if metric == "msg_rates" else 10)
-            plt.scatter(
-                bases[metric][msg_type],
-                deltas[metric][msg_type],
-                s(values[metric][msg_type]),
-            )
-            plt.savefig(
-                os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-    os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            for base_value in base_set:
-                deltas_avg = np.array(list(sorted(delta_set)))
-                values_avg = np.zeros_like(deltas_avg, dtype=np.float64)
-                for i in range(len(deltas_avg)):
-                    mask = (deltas[metric][msg_type] == deltas_avg[i]) & (
-                        bases[metric][msg_type] == base_value
-                    )
-                    values_avg[i] = np.mean(values[metric][msg_type][mask])
-                plt.plot(deltas_avg, values_avg, label=f"Base {base_value}")
-
-            plt.legend()
-            plt.savefig(
-                os.path.join(
-                    save_dir,
-                    f"delta_effect/{metric}_{msg_type}.pdf",
-                ),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-    os.makedirs(os.path.join(save_dir, "msg_len_effect"), exist_ok=True)
-    for metric in data:
-        for msg_type in data[metric]:
-            fig = plt.figure(dpi=300)
-            for base_value in base_set:
-                msgs_lens_avg = np.array(sorted(list(msgs_lens_set)))
-                values_avg = np.zeros_like(msgs_lens_avg, dtype=np.float64)
-                for i in range(len(msgs_lens_avg)):
-                    mask = (msgs_lens[metric][msg_type] == msgs_lens_avg[i]) & (
-                        bases[metric][msg_type] == base_value
-                    )
-                    values_avg[i] = np.mean(values[metric][msg_type][mask])
-
-                plt.plot(msgs_lens_avg, values_avg, label=f"Base {base_value}")
-
-            plt.legend()
-            plt.savefig(
-                os.path.join(
-                    save_dir,
-                    f"msg_len_effect/{metric}_{msg_type}.pdf",
-                ),
-                bbox_inches="tight",
-            )
-            plt.close(fig)
-
-
 def main(args):
-    if not args.results_load_file:
-        prompts = load_prompts(
-            args.num_prompts,
-            args.prompt_size,
-            args.prompts_file if not args.overwrite else None,
-        )
+    if not args.load_file:
+        model, tokenizer = ModelFactory.load_model(args.gen_model)
+        prompts = load_prompts(tokenizer, args.num_prompts, args.prompt_size)

     msgs_lens = []
     for i in np.linspace(
@@ -451,36 +501,44 @@ def main(args):
         for _ in range(args.msgs_per_length):
             msgs_lens.append(i)

-        msgs = load_msgs(
-            msgs_lens,
-            args.msgs_file if not args.overwrite else None,
+        msgs = load_msgs(msgs_lens)
+
+        processor = AnalyseProcessor(
+            save_file=args.save_file,
+            save_freq=args.save_freq,
+            gen_model=args.gen_model,
+            judge_model=args.judge_model,
+            msgs=msgs,
+            bases=args.bases,
+            deltas=np.linspace(
+                args.deltas[0], args.deltas[1], int(args.deltas[2])
+            ).tolist(),
+            prompts=prompts,
+            batch_size=args.batch_size,
+            gen_params=dict(
+                start_pos_p=[0],
+                seed_scheme="dummy_hash",
+                window_length=1,
+                min_new_tokens_ratio=1,
+                max_new_tokens_ratio=1,
+                do_sample=args.do_sample,
+                num_beams=args.num_beams,
+                repetition_penalty=1.0,
+            ),
         )
-
-        if args.msgs_file:
-            if not os.path.isfile(args.msgs_file) or args.overwrite:
-                os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
-                with open(args.msgs_file, "w") as f:
-                    json.dump(msgs, f)
-                print(f"Saved messages to {args.msgs_file}")
-        if args.prompts_file:
-            if not os.path.isfile(args.prompts_file) or args.overwrite:
-                os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
-                with open(args.prompts_file, "w") as f:
-                    json.dump(prompts, f)
-                print(f"Saved prompts to {args.prompts_file}")
-        results = get_results(args, prompts, msgs)
+        processor.save_data(args.save_file)
     else:
-        with open(args.results_load_file, "r") as f:
-            results = json.load(f)
+        processor = AnalyseProcessor(
+            save_file=args.save_file,
+            save_freq=args.save_freq,
+        )
+        processor.load_data(args.load_file)

-    if args.results_save_file:
-        os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)
-        with open(args.results_save_file, "w") as f:
-            json.dump(results, f)
-        print(f"Saved results to {args.results_save_file}")
+    processor.run()
+    processor.plot(args.figs_dir)

-    if args.figs_dir:
-        process_results(results, args.figs_dir)
+    # if args.figs_dir:
+    #     process_results(results, args.figs_dir)


 if __name__ == "__main__":
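The resume behaviour of AnalyseProcessor.run above is the heart of the refactor: the sweep over msgs, bases, and deltas is a recursive nested loop whose current position lives in the JSON-serialisable "ptrs" dict, so a reloaded run continues from the saved grid point. A minimal, standalone sketch of that pointer-driven enumeration (names and shapes here are illustrative, not the repo's API):

# Standalone sketch of the pointer-based resume in AnalyseProcessor.run.
# "ptrs" records where each nested loop stopped; resuming skips every grid
# point before that position, and each finished loop resets its pointer so
# the parent loop can advance.
def sweep(values: dict, ptrs: dict, names=("msgs", "bases", "deltas"), depth=0):
    if depth == len(names):
        yield tuple(values[n][ptrs[n]] for n in names)  # one grid point
        return
    name = names[depth]
    while ptrs[name] < len(values[name]):
        yield from sweep(values, ptrs, names, depth + 1)
        ptrs[name] += 1
    ptrs[name] = 0  # finished: reset for the next pass of the parent loop

values = {"msgs": ["m0", "m1"], "bases": [2, 4], "deltas": [1.0, 2.0]}
print(len(list(sweep(values, {"msgs": 0, "bases": 0, "deltas": 0}))))  # 8
# Resuming from a saved position visits only the remaining grid points:
print(len(list(sweep(values, {"msgs": 1, "bases": 1, "deltas": 0}))))  # 2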
api.py CHANGED
@@ -34,7 +34,7 @@ async def encrypt_api(
 ):
     byte_msg = base64.b64decode(body.msg)
     model, tokenizer = ModelFactory.load_model(body.gen_model)
-    text, msg_rate, tokens_info = generate(
+    texts, msgs_rates, tokens_infos = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=body.prompt,
@@ -49,7 +49,11 @@
         num_beams=body.num_beams,
         repetition_penalty=body.repetition_penalty,
     )
-    return {"text": text, "msg_rate": msg_rate, "tokens_info": tokens_info}
+    return {
+        "texts": texts,
+        "msgs_rates": msgs_rates,
+        "tokens_info": tokens_infos,
+    }


 @app.post(
@@ -114,6 +118,7 @@ async def default_config():
         "max_new_tokens_ratio": GlobalConfig.get(
             "encrypt.default", "max_new_tokens_ratio"
         ),
+        "do_sample": GlobalConfig.get("encrypt.default", "do_sample"),
         "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
         "repetition_penalty": GlobalConfig.get(
             "encrypt.default", "repetition_penalty"
config.ini CHANGED
@@ -34,7 +34,8 @@ window_length = int:1
 private_key = int:0
 min_new_tokens_ratio = float:1.0
 max_new_tokens_ratio = float:2.0
-num_beams = int:4
+do_sample = bool:0
+num_beams = int:1
 repetition_penalty = float:1.0

 [decrypt.default]
demo.py CHANGED
@@ -17,12 +17,14 @@ def enc_fn(
     seed_scheme: str,
     window_length: int,
     private_key: int,
+    do_sample: bool,
+    min_new_tokens_ratio: float,
     max_new_tokens_ratio: float,
     num_beams: int,
     repetition_penalty: float,
 ):
     model, tokenizer = ModelFactory.load_model(gen_model)
-    text, msg_rate, tokens_info = generate(
+    texts, msgs_rates, tokens_infos = generate(
         tokenizer=tokenizer,
         model=model,
         prompt=prompt,
@@ -33,12 +35,14 @@ def enc_fn(
         seed_scheme=seed_scheme,
         window_length=window_length,
         private_key=private_key,
+        do_sample=do_sample,
+        min_new_tokens_ratio=min_new_tokens_ratio,
         max_new_tokens_ratio=max_new_tokens_ratio,
         num_beams=num_beams,
         repetition_penalty=repetition_penalty,
     )
     highlight_base = []
-    for token in tokens_info:
+    for token in tokens_infos[0]:
         stat = None
         if token["base_msg"] != -1:
             if token["base_msg"] == token["base_enc"]:
@@ -48,8 +52,8 @@ def enc_fn(
         highlight_base.append((repr(token["token"])[1:-1], stat))

     highlight_byte = []
-    for i, token in enumerate(tokens_info):
-        if i == 0 or tokens_info[i - 1]["byte_id"] != token["byte_id"]:
+    for i, token in enumerate(tokens_infos[0]):
+        if i == 0 or tokens_infos[0][i - 1]["byte_id"] != token["byte_id"]:
             stat = None
             if token["byte_msg"] != -1:
                 if token["byte_msg"] == token["byte_enc"]:
@@ -60,7 +64,12 @@ def enc_fn(
         else:
             highlight_byte[-1][0] += repr(token["token"])[1:-1]

-    return text, highlight_base, highlight_byte, round(msg_rate * 100, 2)
+    return (
+        texts[0],
+        highlight_base,
+        highlight_byte,
+        round(msgs_rates[0] * 100, 2),
+    )


 def dec_fn(
@@ -108,13 +117,21 @@ if __name__ == "__main__":
             gr.Number(
                 int(GlobalConfig.get("encrypt.default", "window_length"))
             ),
             gr.Number(int(GlobalConfig.get("encrypt.default", "private_key"))),
+            gr.Number(bool(GlobalConfig.get("encrypt.default", "do_sample"))),
+            gr.Number(
+                float(
+                    GlobalConfig.get("encrypt.default", "min_new_tokens_ratio")
+                )
+            ),
             gr.Number(
                 float(
                     GlobalConfig.get("encrypt.default", "max_new_tokens_ratio")
                 )
             ),
             gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))),
-            gr.Number(float(GlobalConfig.get("encrypt.default", "repetition_penalty"))),
+            gr.Number(
+                float(GlobalConfig.get("encrypt.default", "repetition_penalty"))
+            ),
         ],
         outputs=[
             gr.Textbox(
global_config.py CHANGED
@@ -9,7 +9,7 @@ class GlobalConfig:

     @classmethod
     def get_section(cls, section_name):
-        if section_name in cls.config :
+        if section_name in cls.config:
             return cls.config[section_name].keys()
         else:
             return None
@@ -27,6 +27,9 @@ class GlobalConfig:
                 value = float(value)
             elif type_name == "int":
                 value = int(value)
+            elif type_name == "bool":
+                value = bool(value)
+
             return value
         else:
             return None
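The new bool branch follows config.ini's type:value convention (e.g. "do_sample = bool:0"). A standalone sketch of that convention; the explicit string check for bool is an assumption of this sketch rather than what the repo does, since in Python bool("0") is True for any non-empty string:

# Standalone sketch of the "type:value" convention used in config.ini.
def parse_typed(raw: str):
    type_name, sep, value = raw.partition(":")
    if sep == "":
        return raw  # no type prefix: keep the raw string
    if type_name == "float":
        return float(value)
    if type_name == "int":
        return int(value)
    if type_name == "bool":
        # bool("0") is True in Python, so parse the string explicitly.
        return value.strip().lower() not in ("", "0", "false")
    return value

assert parse_typed("int:1") == 1
assert parse_typed("float:2.0") == 2.0
assert parse_typed("bool:0") is False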
main.py CHANGED
@@ -61,6 +61,17 @@ def create_args():
         default=GlobalConfig.get("encrypt.default", "num_beams"),
         help="Number of beams used in beam search",
     )
+    parser.add_argument(
+        "--do-sample",
+        action="store_true",
+        help="Whether to do sample or greedy search",
+    )
+    parser.add_argument(
+        "--min-new-tokens-ratio",
+        type=float,
+        default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
+        help="Ratio of min new tokens to minimum tokens required to hide message",
+    )
     parser.add_argument(
         "--max-new-tokens-ratio",
         type=float,
@@ -183,6 +194,8 @@ def main(args):
         window_length=args.window_length,
         salt_key=args.salt_key,
         private_key=args.private_key,
+        do_sample=args.do_sample,
+        min_new_tokens_ratio=args.min_new_tokens_ratio,
         max_new_tokens_ratio=args.max_new_tokens_ratio,
         num_beams=args.num_beams,
     )
processors.py CHANGED
@@ -116,10 +116,10 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             delta: bias add to scores of token in valid list.
         """
         super().__init__(*args, **kwargs)
-        if prompt_ids.size(0) != 1:
-            raise RuntimeError(
-                "EncryptorLogitsProcessor does not support multiple prompts input."
-            )
+        # if prompt_ids.size(0) != 1:
+        #     raise RuntimeError(
+        #         "EncryptorLogitsProcessor does not support multiple prompts input."
+        #     )

         self.prompt_size = prompt_ids.size(1)
         self.start_pos = start_pos
@@ -192,7 +192,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         return base_msg, byte_msg, base_enc_msg, byte_enc_msg

     def validate(self, input_ids_batch: torch.Tensor):
-        res = []
+        msgs_rates = []
         tokens_infos = []
         for input_ids in input_ids_batch:
             # Initialization
@@ -214,7 +214,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             for i in range(min(len(enc_msg), len(self.raw_msg))):
                 if self.raw_msg[i] == enc_msg[i]:
                     cnt += 1
-            res.append(cnt / len(self.raw_msg))
+            msgs_rates.append(cnt / len(self.raw_msg))

             base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
                 self.__map_input_ids(input_ids, base_arr, byte_arr)
@@ -234,7 +234,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             )
             tokens_infos.append(tokens)

-        return res, tokens_infos
+        return msgs_rates, tokens_infos


 class DecryptorProcessor(BaseProcessor):
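Each entry of msgs_rates is the fraction of message units that survive generation, compared position by position against the original message. A self-contained sketch of that metric (the helper name is illustrative, not the repo's API):

def msg_rate(raw_msg: bytes, enc_msg: bytes) -> float:
    # zip stops at the shorter sequence, mirroring the min(...) bound in
    # validate(); the denominator is always the full original length.
    matches = sum(1 for a, b in zip(raw_msg, enc_msg) if a == b)
    return matches / len(raw_msg)

assert msg_rate(b"secret", b"secret") == 1.0
assert msg_rate(b"secret", b"sec") == 0.5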
schemes.py CHANGED
@@ -12,7 +12,7 @@ with open("resources/examples.json", "r") as f:


 class EncryptionBody(BaseModel):
-    prompt: str = Field(title="Prompt used to generate text")
+    prompt: str | list[str] = Field(title="Prompt used to generate text")
     msg: str = Field(title="Message wanted to hide")
     gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
         default=GlobalConfig.get("encrypt.default", "gen_model"),
seed_schemes/dummy_hash.py CHANGED
@@ -8,5 +8,7 @@ class DummyHash(SeedScheme):
         pass

     def __call__(self, input_ids: torch.Tensor):
+        if input_ids.size(0) == 0:
+            return 0
         return int(input_ids[-1].item())
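The guard matters because indexing input_ids[-1] on an empty tensor raises IndexError; with it, an empty context window falls back to seed 0. A standalone sketch of the same logic:

import torch

def dummy_hash_seed(input_ids: torch.Tensor) -> int:
    if input_ids.size(0) == 0:
        return 0  # empty context window: fixed fallback seed
    return int(input_ids[-1].item())

print(dummy_hash_seed(torch.tensor([], dtype=torch.long)))  # 0
print(dummy_hash_seed(torch.tensor([17, 42])))              # 42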
 
stegno.py CHANGED
@@ -9,7 +9,7 @@ from processors import EncryptorLogitsProcessor, DecryptorProcessor
 def generate(
     tokenizer,
     model,
-    prompt: str,
+    prompt: str | list[str],
     msg: bytes,
     start_pos_p: list[int],
     delta: float,
@@ -20,9 +20,10 @@
     private_key: Union[int, None] = None,
     min_new_tokens_ratio: float = 1,
     max_new_tokens_ratio: float = 2,
-    num_beams: int = 4,
+    do_sample: bool = True,
+    num_beams: int = 1,
     repetition_penalty: float = 1.0,
-    prompt_size: int = -1,
+    generator: torch.Generator | None = None,
 ):
     """
     Generate the sequence containing the hidden data.
@@ -46,12 +47,13 @@
         start_pos_p[0], start_pos_p[1] + 1, (1,)
     ).item()
     start_pos = int(start_pos) + window_length
+    tokenizer.pad_token = tokenizer.eos_token
+
+    tokenized_input = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
+    prompt_size = tokenized_input.input_ids.size(1)

-    tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
-    if prompt_size == -1:
-        prompt_size = tokenized_input.input_ids.size(1)
     logits_processor = EncryptorLogitsProcessor(
-        prompt_ids=tokenized_input.input_ids[:, :prompt_size],
+        prompt_ids=tokenized_input.input_ids,
         msg=msg,
         start_pos=start_pos,
         delta=delta,
@@ -77,29 +79,32 @@ def generate(
     max_length = min(max_length, tokenizer.model_max_length)
     min_length = min(min_length, max_length)
     output_tokens = model.generate(
-        input_ids=tokenized_input.input_ids[:, :prompt_size],
-        attention_mask=tokenized_input.attention_mask[:, :prompt_size],
+        **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
         min_length=min_length,
         max_length=max_length,
-        do_sample=True,
+        do_sample=do_sample,
         num_beams=num_beams,
         repetition_penalty=float(repetition_penalty),
         pad_token_id=tokenizer.eos_token_id,
+        generator=generator,
     )

     output_tokens = output_tokens[:, prompt_size:]
-    output_text = tokenizer.batch_decode(
+    output_texts = tokenizer.batch_decode(
         output_tokens, skip_special_tokens=True
-    )[0]
+    )
     output_tokens_post = tokenizer(
-        output_text, return_tensors="pt", add_special_tokens=False
+        output_texts,
+        return_tensors="pt",
+        add_special_tokens=False,
+        padding=True,
    ).to(model.device)
     msg_rates, tokens_infos = logits_processor.validate(
         output_tokens_post.input_ids
     )

-    return output_text, msg_rates[0], tokens_infos[0]
+    return output_texts, msg_rates, tokens_infos


 def decrypt(
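With this change, prompt may be a single string or a list, and all three return values are per-sequence lists. A usage sketch of the new batched signature; the ModelFactory import path and the parameter values are assumptions for illustration, not taken from this diff:

from stegno import generate  # module name from this repo's stegno.py
# ModelFactory's module is not shown in this diff; import it from wherever
# it actually lives in the repo.

model, tokenizer = ModelFactory.load_model("gpt2")
texts, msgs_rates, tokens_infos = generate(
    tokenizer=tokenizer,
    model=model,
    prompt=["Once upon a time", "The weather today is"],  # batch of two
    msg=b"hi",
    start_pos_p=[0],
    delta=5.0,
    msg_base=2,
    seed_scheme="dummy_hash",
    window_length=1,
    do_sample=False,
    num_beams=1,
)
assert len(texts) == len(msgs_rates) == len(tokens_infos) == 2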