tnk2908 committed on
Commit
11c7796
1 Parent(s): 0c3c1a0

Show the percentage of the message hidden in the results; improve the UI of the command-line interface

Browse files
Files changed (3) hide show
  1. main.py +85 -27
  2. processors.py +17 -0
  3. stegno.py +9 -4
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from argparse import ArgumentParser
3
 
4
  import torch
@@ -51,18 +52,37 @@ def create_args():
51
  parser.add_argument(
52
  "--private-key", type=str, default="", help="Path to private key"
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Input
55
  parser.add_argument(
56
- "--msg", type=str, required=True, help="Path to file containing message"
 
 
 
57
  )
58
  parser.add_argument(
59
- "--prompt", type=str, default=None, help="Prompt used to generate text"
 
 
 
60
  )
61
  parser.add_argument(
62
  "--text",
63
  type=str,
64
  default=None,
65
- help="Text contains the hidden message",
66
  )
67
  # Mode
68
  parser.add_argument(
@@ -89,74 +109,109 @@ def main(args):
89
 
90
  if os.path.isfile(args.salt_key):
91
  with open(args.salt_key, "r") as f:
92
- salt_key = int(f.readline())
 
93
  else:
94
- salt_key = None
95
 
96
  if os.path.isfile(args.private_key):
97
  with open(args.private_key, "r") as f:
98
- private_key = int(f.readline())
 
99
  else:
100
- private_key = None
 
 
101
 
102
  if args.encrypt:
103
  if len(args.prompt) == 0:
104
  raise ValueError("Prompt cannot be empty in encrypt mode")
 
 
 
 
 
 
 
 
105
  if os.path.isfile(args.msg):
 
106
  with open(args.msg, "rb") as f:
107
- msg = f.read()
108
  else:
109
- raise ValueError(f"Message file {args.msg} is not a file")
110
 
111
  print("=" * os.get_terminal_size().columns)
112
  print("Encryption Parameters:")
113
  print(f" GenModel: {args.gen_model}")
114
- print(f" Prompt: {args.prompt}")
115
- print(f" Message: {msg}")
 
 
 
 
 
 
116
  print(f" Gamma: {args.gamma}")
117
  print(f" Message Base: {args.msg_base}")
118
  print(f" Seed Scheme: {args.seed_scheme}")
119
  print(f" Window Length: {args.window_length}")
120
- print(f" Salt Key: {salt_key}")
121
- print(f" Private Key: {private_key}")
 
 
122
  print("=" * os.get_terminal_size().columns)
123
- text = generate(
124
  tokenizer=tokenizer,
125
  model=model,
126
  prompt=args.prompt,
127
- msg=msg,
128
  gamma=args.gamma,
129
  msg_base=args.msg_base,
130
  seed_scheme=args.seed_scheme,
131
  window_length=args.window_length,
132
- salt_key=salt_key,
133
- private_key=private_key,
 
 
134
  )
135
- print(f"Text contains message:\n{text}")
136
-
137
- if os.path.isfile(args.save_file):
 
 
 
 
 
 
138
  with open(args.save_file, "w") as f:
139
  f.write(text)
140
-
141
- args.text = text
142
 
143
  if args.decrypt:
144
  if len(args.text) == 0:
145
  raise ValueError("Text cannot be empty in decrypt mode")
 
146
  if os.path.isfile(args.text):
 
147
  with open(args.text, "r") as f:
148
  lines = f.readlines()
149
  args.text = "".join(lines)
 
150
  print("=" * os.get_terminal_size().columns)
151
- print("Encryption Parameters:")
152
  print(f" GenModel: {args.gen_model}")
153
- print(f" Text: {args.text}")
154
  print(f" Message Base: {args.msg_base}")
155
  print(f" Seed Scheme: {args.seed_scheme}")
156
  print(f" Window Length: {args.window_length}")
157
- print(f" Salt Key: {salt_key}")
158
- print(f" Private Key: {private_key}")
 
 
 
 
159
  print("=" * os.get_terminal_size().columns)
 
160
  msgs = decrypt(
161
  tokenizer=tokenizer,
162
  device=args.device,
@@ -169,7 +224,10 @@ def main(args):
169
  )
170
  print("Message:")
171
  for s, msg in enumerate(msgs):
172
- print(f"Shift {s}: {msg}")
 
 
 
173
 
174
 
175
  if __name__ == "__main__":
 
1
  import os
2
+ import json
3
  from argparse import ArgumentParser
4
 
5
  import torch
 
52
  parser.add_argument(
53
  "--private-key", type=str, default="", help="Path to private key"
54
  )
55
+ # Generation Params
56
+ parser.add_argument(
57
+ "--num-beams",
58
+ type=int,
59
+ default=4,
60
+ help="Number of beams used in beam search",
61
+ )
62
+ parser.add_argument(
63
+ "--max-new-tokens-ratio",
64
+ type=float,
65
+ default=2,
66
+ help="Ratio of max new tokens to minimum tokens required to hide message",
67
+ )
68
  # Input
69
  parser.add_argument(
70
+ "--msg",
71
+ type=str,
72
+ default=None,
73
+ help="Message or path to message to be hidden",
74
  )
75
  parser.add_argument(
76
+ "--prompt",
77
+ type=str,
78
+ default=None,
79
+ help="Prompt or path to prompt used to generate text",
80
  )
81
  parser.add_argument(
82
  "--text",
83
  type=str,
84
  default=None,
85
+ help="Text or path to text containing the hidden message",
86
  )
87
  # Mode
88
  parser.add_argument(
 
109
 
110
  if os.path.isfile(args.salt_key):
111
  with open(args.salt_key, "r") as f:
112
+ args.salt_key = int(f.readline())
113
+ print(f"Read salt key from {args.salt_key}")
114
  else:
115
+ args.salt_key = int(args.salt_key) if len(args.salt_key) > 0 else None
116
 
117
  if os.path.isfile(args.private_key):
118
  with open(args.private_key, "r") as f:
119
+ args.private_key = int(f.readline())
120
+ print(f"Read private key from {args.private_key}")
121
  else:
122
+ args.private_key = (
123
+ int(args.private_key) if len(args.private_key) > 0 else None
124
+ )
125
 
126
  if args.encrypt:
127
  if len(args.prompt) == 0:
128
  raise ValueError("Prompt cannot be empty in encrypt mode")
129
+ if len(args.msg) == 0:
130
+ raise ValueError("Message cannot be empty in encrypt mode")
131
+
132
+ if os.path.isfile(args.prompt):
133
+ print(f"Read prompt from {args.prompt}")
134
+ with open(args.prompt, "r") as f:
135
+ args.prompt = "".join(f.readlines())
136
+
137
  if os.path.isfile(args.msg):
138
+ print(f"Read message from {args.msg}")
139
  with open(args.msg, "rb") as f:
140
+ args.msg = f.read()
141
  else:
142
+ args.msg = bytes(args.msg)
143
 
144
  print("=" * os.get_terminal_size().columns)
145
  print("Encryption Parameters:")
146
  print(f" GenModel: {args.gen_model}")
147
+ print(f" Prompt:")
148
+ print("- " * (os.get_terminal_size().columns // 2))
149
+ print(args.prompt)
150
+ print("- " * (os.get_terminal_size().columns // 2))
151
+ print(f" Message:")
152
+ print("- " * (os.get_terminal_size().columns // 2))
153
+ print(args.msg)
154
+ print("- " * (os.get_terminal_size().columns // 2))
155
  print(f" Gamma: {args.gamma}")
156
  print(f" Message Base: {args.msg_base}")
157
  print(f" Seed Scheme: {args.seed_scheme}")
158
  print(f" Window Length: {args.window_length}")
159
+ print(f" Salt Key: {args.salt_key}")
160
+ print(f" Private Key: {args.private_key}")
161
+ print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
162
+ print(f" Number of Beams: {args.num_beams}")
163
  print("=" * os.get_terminal_size().columns)
164
+ text, msg_rate = generate(
165
  tokenizer=tokenizer,
166
  model=model,
167
  prompt=args.prompt,
168
+ msg=args.msg,
169
  gamma=args.gamma,
170
  msg_base=args.msg_base,
171
  seed_scheme=args.seed_scheme,
172
  window_length=args.window_length,
173
+ salt_key=args.salt_key,
174
+ private_key=args.private_key,
175
+ max_new_tokens_ratio=args.max_new_tokens_ratio,
176
+ num_beams=args.num_beams,
177
  )
178
+ print(f"Text contains message:")
179
+ print("-" * (os.get_terminal_size().columns))
180
+ print(text)
181
+ print("-" * (os.get_terminal_size().columns))
182
+ print(f"Successfully hide {msg_rate*100:.2f} of the message")
183
+ print("-" * (os.get_terminal_size().columns))
184
+
185
+ if len(args.save_file) > 0:
186
+ os.makedirs(os.path.dirname(args.save_file), exist_ok=True)
187
  with open(args.save_file, "w") as f:
188
  f.write(text)
189
+ print(f"Saved result to {args.save_file}")
 
190
 
191
  if args.decrypt:
192
  if len(args.text) == 0:
193
  raise ValueError("Text cannot be empty in decrypt mode")
194
+
195
  if os.path.isfile(args.text):
196
+ print(f"Read text from {args.text}")
197
  with open(args.text, "r") as f:
198
  lines = f.readlines()
199
  args.text = "".join(lines)
200
+
201
  print("=" * os.get_terminal_size().columns)
202
+ print("Decryption Parameters:")
203
  print(f" GenModel: {args.gen_model}")
 
204
  print(f" Message Base: {args.msg_base}")
205
  print(f" Seed Scheme: {args.seed_scheme}")
206
  print(f" Window Length: {args.window_length}")
207
+ print(f" Salt Key: {args.salt_key}")
208
+ print(f" Private Key: {args.private_key}")
209
+ print(f" Text:")
210
+ print("- " * (os.get_terminal_size().columns // 2))
211
+ print(args.text)
212
+ print("- " * (os.get_terminal_size().columns // 2))
213
  print("=" * os.get_terminal_size().columns)
214
+
215
  msgs = decrypt(
216
  tokenizer=tokenizer,
217
  device=args.device,
 
224
  )
225
  print("Message:")
226
  for s, msg in enumerate(msgs):
227
+ print("-" * (os.get_terminal_size().columns))
228
+ print(f"Shift {s}: ")
229
+ print(msg[0])
230
+ print("-" * (os.get_terminal_size().columns))
231
 
232
 
233
  if __name__ == "__main__":
processors.py CHANGED
@@ -107,6 +107,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
107
  self.start_pos = []
108
  for i in range(prompt_ids.size(0)):
109
  self.start_pos.append(prompt_ids[i].size(0))
 
110
  self.msg = bytes_to_base(msg, self.msg_base)
111
  self.gamma = gamma
112
 
@@ -139,6 +140,22 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
139
  def get_message_len(self):
140
  return len(self.msg)
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  class DecryptorProcessor(BaseProcessor):
144
  def __init__(self, *args, **kwargs):
 
107
  self.start_pos = []
108
  for i in range(prompt_ids.size(0)):
109
  self.start_pos.append(prompt_ids[i].size(0))
110
+ self.raw_msg = msg
111
  self.msg = bytes_to_base(msg, self.msg_base)
112
  self.gamma = gamma
113
 
 
140
  def get_message_len(self):
141
  return len(self.msg)
142
 
143
+ def validate(self, input_ids_batch: torch.Tensor):
144
+ res = []
145
+ for input_ids in input_ids_batch:
146
+ values = []
147
+ for i in range(self.start_pos[0], input_ids.size(0)):
148
+ values.append(self._get_value(input_ids[: i + 1]))
149
+ enc_msg = base_to_bytes(values, self.msg_base)
150
+ cnt = 0
151
+ for i in range(len(self.raw_msg)):
152
+ if self.raw_msg[i] == enc_msg[i]:
153
+ cnt += 1
154
+ res.append(cnt / len(self.raw_msg))
155
+
156
+
157
+ return res
158
+
159
 
160
  class DecryptorProcessor(BaseProcessor):
161
  def __init__(self, *args, **kwargs):
stegno.py CHANGED
@@ -17,6 +17,8 @@ def generate(
17
  window_length: int = 1,
18
  salt_key: Union[int, None] = None,
19
  private_key: Union[int, None] = None,
 
 
20
  ):
21
  """
22
  Generate the sequence containing the hidden data.
@@ -51,15 +53,19 @@ def generate(
51
  **tokenized_input,
52
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
53
  min_new_tokens=logits_processor.get_message_len(),
54
- max_new_tokens=logits_processor.get_message_len() * 2,
 
 
55
  do_sample=True,
56
- num_beams=4,
57
  )
58
  output_text = tokenizer.batch_decode(
59
  output_tokens, skip_special_tokens=True
60
  )[0]
 
 
61
 
62
- return output_text
63
 
64
 
65
  def decrypt(
@@ -100,4 +106,3 @@ def decrypt(
100
  msg = decryptor.decrypt(tokenized_input.input_ids)
101
 
102
  return msg
103
-
 
17
  window_length: int = 1,
18
  salt_key: Union[int, None] = None,
19
  private_key: Union[int, None] = None,
20
+ max_new_tokens_ratio: float = 2,
21
+ num_beams: int = 4,
22
  ):
23
  """
24
  Generate the sequence containing the hidden data.
 
53
  **tokenized_input,
54
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
55
  min_new_tokens=logits_processor.get_message_len(),
56
+ max_new_tokens=int(
57
+ logits_processor.get_message_len() * max_new_tokens_ratio
58
+ ),
59
  do_sample=True,
60
+ num_beams=num_beams
61
  )
62
  output_text = tokenizer.batch_decode(
63
  output_tokens, skip_special_tokens=True
64
  )[0]
65
+ output_tokens_post = tokenizer(output_text, return_tensors="pt")
66
+ msg_rates = logits_processor.validate(output_tokens_post.input_ids)
67
 
68
+ return output_text, msg_rates[0]
69
 
70
 
71
  def decrypt(
 
106
  msg = decryptor.decrypt(tokenized_input.input_ids)
107
 
108
  return msg