tnk2908 committed on
Commit
7e8d9b9
1 Parent(s): 91d5a5e

Allow users to change the start_pos

Browse files
Files changed (6) hide show
  1. main.py +10 -1
  2. processors.py +29 -16
  3. requirements.txt +1 -0
  4. seed_schemes.py +0 -39
  5. stegno.py +22 -5
  6. test.py +14 -0
main.py CHANGED
@@ -84,6 +84,14 @@ def create_args():
84
  default=None,
85
  help="Text or path to text containing the hidden message",
86
  )
 
 
 
 
 
 
 
 
87
  # Mode
88
  parser.add_argument(
89
  "--encrypt",
@@ -166,6 +174,7 @@ def main(args):
166
  model=model,
167
  prompt=args.prompt,
168
  msg=args.msg,
 
169
  gamma=args.gamma,
170
  msg_base=args.msg_base,
171
  seed_scheme=args.seed_scheme,
@@ -179,7 +188,7 @@ def main(args):
179
  print("-" * (os.get_terminal_size().columns))
180
  print(text)
181
  print("-" * (os.get_terminal_size().columns))
182
- print(f"Successfully hide {msg_rate*100:.2f} of the message")
183
  print("-" * (os.get_terminal_size().columns))
184
 
185
  if len(args.save_file) > 0:
 
84
  default=None,
85
  help="Text or path to text containing the hidden message",
86
  )
87
+ # Encryption params
88
+ parser.add_argument(
89
+ "--start-pos",
90
+ type=int,
91
+ nargs="+",
92
+ default=[0],
93
+ help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
94
+ )
95
  # Mode
96
  parser.add_argument(
97
  "--encrypt",
 
174
  model=model,
175
  prompt=args.prompt,
176
  msg=args.msg,
177
+ start_pos_p=args.start_pos,
178
  gamma=args.gamma,
179
  msg_base=args.msg_base,
180
  seed_scheme=args.seed_scheme,
 
188
  print("-" * (os.get_terminal_size().columns))
189
  print(text)
190
  print("-" * (os.get_terminal_size().columns))
191
+ print(f"Successfully hide {msg_rate*100:.2f}% of the message")
192
  print("-" * (os.get_terminal_size().columns))
193
 
194
  if len(args.save_file) > 0:
processors.py CHANGED
@@ -4,7 +4,7 @@ from typing import Union
4
  import torch
5
  from transformers import LogitsProcessor
6
 
7
- from seed_schemes import seed_scheme_factory
8
  from utils import bytes_to_base, base_to_bytes, get_values_per_byte
9
 
10
 
@@ -36,15 +36,20 @@ class BaseProcessor(object):
36
  self.device = device
37
 
38
  # Seed parameters
39
- self.seed_fn = seed_scheme_factory.get(
40
  seed_scheme,
41
  salt_key=salt_key,
42
  private_key=private_key,
43
  )
 
 
 
 
 
44
  self.window_length = window_length
45
 
46
- # Initialize RNG
47
- self.rng = torch.Generator(device=device)
48
 
49
  # Compute the ranges of each value in base
50
  self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64)
@@ -69,7 +74,9 @@ class BaseProcessor(object):
69
  Get ids of tokens in the valid list for the current sequences.
70
  """
71
  self._seed_rng(input_ids)
72
- vocab_perm = torch.randperm(self.vocab_size, generator=self.rng)
 
 
73
  vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]
74
 
75
  return vocab_list
@@ -79,7 +86,9 @@ class BaseProcessor(object):
79
  Check whether the token is in the valid list.
80
  """
81
  self._seed_rng(input_ids[:-1])
82
- vocab_perm = torch.randperm(self.vocab_size, generator=self.rng)
 
 
83
 
84
  cur_token = input_ids[-1]
85
  cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
@@ -94,8 +103,9 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
94
  prompt_ids: torch.Tensor,
95
  msg: bytes,
96
  gamma: float,
 
97
  *args,
98
- **kwargs
99
  ):
100
  """
101
  Args:
@@ -103,10 +113,14 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
103
  gamma: bias add to scores of token in valid list.
104
  """
105
  super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
106
 
107
- self.start_pos = []
108
- for i in range(prompt_ids.size(0)):
109
- self.start_pos.append(prompt_ids[i].size(0))
110
  self.raw_msg = msg
111
  self.msg = bytes_to_base(msg, self.msg_base)
112
  self.gamma = gamma
@@ -118,8 +132,8 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
118
 
119
  for i, input_ids in enumerate(input_ids_batch):
120
  cur_pos = input_ids.size(0)
121
- msg_ptr = cur_pos - self.start_pos[0]
122
- if msg_ptr >= len(self.msg):
123
  continue
124
  scores_batch[i] = self._add_bias_to_valid_list(
125
  input_ids, scores_batch[i], self.msg[msg_ptr]
@@ -144,7 +158,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
144
  res = []
145
  for input_ids in input_ids_batch:
146
  values = []
147
- for i in range(self.start_pos[0], input_ids.size(0)):
148
  values.append(self._get_value(input_ids[: i + 1]))
149
  enc_msg = base_to_bytes(values, self.msg_base)
150
  cnt = 0
@@ -152,7 +166,6 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
152
  if self.raw_msg[i] == enc_msg[i]:
153
  cnt += 1
154
  res.append(cnt / len(self.raw_msg))
155
-
156
 
157
  return res
158
 
@@ -166,12 +179,12 @@ class DecryptorProcessor(BaseProcessor):
166
  Decrypt the text sequences.
167
  """
168
  shift_msg = []
169
- for s in range(get_values_per_byte(self.msg_base)):
170
  msg = []
171
  bytes_msg = []
172
  for i, input_ids in enumerate(input_ids_batch):
173
  msg.append(list())
174
- for j in range(self.window_length + s, len(input_ids)):
175
  # TODO: this could be slow. Considering reimplement this.
176
  value = self._get_value(input_ids[: j + 1])
177
  msg[i].append(value)
 
4
  import torch
5
  from transformers import LogitsProcessor
6
 
7
+ from seed_scheme_factory import SeedSchemeFactory
8
  from utils import bytes_to_base, base_to_bytes, get_values_per_byte
9
 
10
 
 
36
  self.device = device
37
 
38
  # Seed parameters
39
+ seed_fn = SeedSchemeFactory.get_instance(
40
  seed_scheme,
41
  salt_key=salt_key,
42
  private_key=private_key,
43
  )
44
+ if seed_fn is None:
45
+ raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
46
+ else:
47
+ self.seed_fn = seed_fn
48
+
49
  self.window_length = window_length
50
 
51
+ # Initialize RNG, always use cpu generator
52
+ self.rng = torch.Generator(device="cpu")
53
 
54
  # Compute the ranges of each value in base
55
  self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64)
 
74
  Get ids of tokens in the valid list for the current sequences.
75
  """
76
  self._seed_rng(input_ids)
77
+ vocab_perm = torch.randperm(
78
+ self.vocab_size, generator=self.rng, device="cpu"
79
+ ).to(self.device)
80
  vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]
81
 
82
  return vocab_list
 
86
  Check whether the token is in the valid list.
87
  """
88
  self._seed_rng(input_ids[:-1])
89
+ vocab_perm = torch.randperm(
90
+ self.vocab_size, generator=self.rng, device="cpu"
91
+ ).to(self.device)
92
 
93
  cur_token = input_ids[-1]
94
  cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
 
103
  prompt_ids: torch.Tensor,
104
  msg: bytes,
105
  gamma: float,
106
+ start_pos: int = 0,
107
  *args,
108
+ **kwargs,
109
  ):
110
  """
111
  Args:
 
113
  gamma: bias add to scores of token in valid list.
114
  """
115
  super().__init__(*args, **kwargs)
116
+ if prompt_ids.size(0) != 1:
117
+ raise RuntimeError(
118
+ "EncryptorLogitsProcessor does not support multiple prompts input."
119
+ )
120
+
121
+ self.prompt_size = prompt_ids.size(1)
122
+ self.start_pos = start_pos
123
 
 
 
 
124
  self.raw_msg = msg
125
  self.msg = bytes_to_base(msg, self.msg_base)
126
  self.gamma = gamma
 
132
 
133
  for i, input_ids in enumerate(input_ids_batch):
134
  cur_pos = input_ids.size(0)
135
+ msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
136
+ if msg_ptr < 0 or msg_ptr >= len(self.msg):
137
  continue
138
  scores_batch[i] = self._add_bias_to_valid_list(
139
  input_ids, scores_batch[i], self.msg[msg_ptr]
 
158
  res = []
159
  for input_ids in input_ids_batch:
160
  values = []
161
+ for i in range(self.start_pos, input_ids.size(0)):
162
  values.append(self._get_value(input_ids[: i + 1]))
163
  enc_msg = base_to_bytes(values, self.msg_base)
164
  cnt = 0
 
166
  if self.raw_msg[i] == enc_msg[i]:
167
  cnt += 1
168
  res.append(cnt / len(self.raw_msg))
 
169
 
170
  return res
171
 
 
179
  Decrypt the text sequences.
180
  """
181
  shift_msg = []
182
+ for shift in range(get_values_per_byte(self.msg_base)):
183
  msg = []
184
  bytes_msg = []
185
  for i, input_ids in enumerate(input_ids_batch):
186
  msg.append(list())
187
+ for j in range(self.window_length + shift, len(input_ids)):
188
  # TODO: this could be slow. Considering reimplement this.
189
  value = self._get_value(input_ids[: j + 1])
190
  msg[i].append(value)
requirements.txt CHANGED
@@ -5,3 +5,4 @@ PyYAML==6.0.1
5
  scikit-learn==1.5.0
6
  torch==2.3.0
7
  cryptography==42.0.8
 
 
5
  scikit-learn==1.5.0
6
  torch==2.3.0
7
  cryptography==42.0.8
8
+ fastapi
seed_schemes.py DELETED
@@ -1,39 +0,0 @@
1
- from typing import Union, Callable
2
-
3
- import torch
4
-
5
-
6
- class SeedSchemeFactory:
7
- def __init__(self):
8
- self.seed_scheme_dict = dict()
9
-
10
- def register(self, name: str, seed_scheme: type):
11
- """
12
- Register the hash scheme by name. Hash scheme must be callable.
13
-
14
- Args:
15
- name: name of seed scheme.
16
- func: seed function.
17
- """
18
- self.seed_scheme_dict[name] = seed_scheme
19
-
20
- def get(self, name: str, **kwargs):
21
- """
22
- Get the hash scheme by name.
23
-
24
- Args:
25
- name: name of seed scheme.
26
- """
27
- return self.seed_scheme_dict[name](**kwargs)
28
-
29
-
30
- class DummyHash:
31
- def __init__(self, *args, **kwargs):
32
- pass
33
-
34
- def __call__(self, input_ids: torch.Tensor):
35
- return input_ids[-1].item()
36
-
37
-
38
- seed_scheme_factory = SeedSchemeFactory()
39
- seed_scheme_factory.register("dummy_hash", DummyHash)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stegno.py CHANGED
@@ -11,6 +11,7 @@ def generate(
11
  model,
12
  prompt: str,
13
  msg: bytes,
 
14
  gamma: float,
15
  msg_base: int,
16
  seed_scheme: str,
@@ -36,10 +37,20 @@ def generate(
36
  private_key: private key used to compute the seed.
37
 
38
  """
 
 
 
 
 
 
 
 
39
  tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
 
40
  logits_processor = EncryptorLogitsProcessor(
41
  prompt_ids=tokenized_input.input_ids,
42
  msg=msg,
 
43
  gamma=gamma,
44
  msg_base=msg_base,
45
  vocab=list(tokenizer.get_vocab().values()),
@@ -49,20 +60,26 @@ def generate(
49
  salt_key=salt_key,
50
  private_key=private_key,
51
  )
 
 
 
 
52
  output_tokens = model.generate(
53
  **tokenized_input,
54
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
55
- min_new_tokens=logits_processor.get_message_len(),
56
- max_new_tokens=int(
57
- logits_processor.get_message_len() * max_new_tokens_ratio
58
- ),
59
  do_sample=True,
60
  num_beams=num_beams
61
  )
 
 
62
  output_text = tokenizer.batch_decode(
63
  output_tokens, skip_special_tokens=True
64
  )[0]
65
- output_tokens_post = tokenizer(output_text, return_tensors="pt")
 
 
66
  msg_rates = logits_processor.validate(output_tokens_post.input_ids)
67
 
68
  return output_text, msg_rates[0]
 
11
  model,
12
  prompt: str,
13
  msg: bytes,
14
+ start_pos_p: list[int],
15
  gamma: float,
16
  msg_base: int,
17
  seed_scheme: str,
 
37
  private_key: private key used to compute the seed.
38
 
39
  """
40
+ if len(start_pos_p) == 1:
41
+ start_pos = start_pos_p[0]
42
+ else:
43
+ start_pos = torch.randint(
44
+ start_pos_p[0], start_pos_p[1] + 1, (1,)
45
+ ).item()
46
+ start_pos = int(start_pos) + window_length
47
+
48
  tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
49
+ prompt_size = tokenized_input.input_ids.size(1)
50
  logits_processor = EncryptorLogitsProcessor(
51
  prompt_ids=tokenized_input.input_ids,
52
  msg=msg,
53
+ start_pos=start_pos,
54
  gamma=gamma,
55
  msg_base=msg_base,
56
  vocab=list(tokenizer.get_vocab().values()),
 
60
  salt_key=salt_key,
61
  private_key=private_key,
62
  )
63
+ min_length = start_pos + logits_processor.get_message_len()
64
+ max_length = int(
65
+ start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
66
+ )
67
  output_tokens = model.generate(
68
  **tokenized_input,
69
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
70
+ min_new_tokens=min_length,
71
+ max_new_tokens=max_length,
 
 
72
  do_sample=True,
73
  num_beams=num_beams
74
  )
75
+
76
+ output_tokens = output_tokens[:, prompt_size:]
77
  output_text = tokenizer.batch_decode(
78
  output_tokens, skip_special_tokens=True
79
  )[0]
80
+ output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
81
+ model.device
82
+ )
83
  msg_rates = logits_processor.validate(output_tokens_post.input_ids)
84
 
85
  return output_text, msg_rates[0]
test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from seed_scheme_factory import SeedSchemeFactory
2
+ import torch
3
+
4
+ if __name__ == "__main__":
5
+ seed_fn = SeedSchemeFactory.get_instance("sha_left_hash", private_key=18)
6
+ rng = torch.Generator()
7
+ rng.manual_seed(1)
8
+ input_ids = torch.randint(0, 2**32, (8,), generator=rng)
9
+ print("input_ids =", input_ids)
10
+ if seed_fn is not None:
11
+ seed = seed_fn(input_ids)
12
+ print(" ", 2**64-1)
13
+ print("seed =", seed)
14
+