tnk2908 committed on
Commit
cc8b2eb
1 Parent(s): f52f4a7

Add restrictions and rename parameters to match the ones written in the report

Browse files
Files changed (7) hide show
  1. api.py +2 -2
  2. config.ini +2 -2
  3. demo.py +3 -3
  4. main.py +4 -4
  5. processors.py +5 -5
  6. schemes.py +77 -23
  7. stegno.py +4 -4
api.py CHANGED
@@ -25,7 +25,7 @@ async def encrypt_api(
25
  prompt=body.prompt,
26
  msg=str.encode(body.msg),
27
  start_pos_p=[body.start_pos],
28
- gamma=body.gamma,
29
  msg_base=body.msg_base,
30
  seed_scheme=body.seed_scheme,
31
  window_length=body.window_length,
@@ -64,7 +64,7 @@ async def default_config():
64
  "encrypt": {
65
  "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
66
  "start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
67
- "gamma": GlobalConfig.get("encrypt.default", "gamma"),
68
  "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
69
  "seed_scheme": GlobalConfig.get(
70
  "encrypt.default", "seed_scheme"
 
25
  prompt=body.prompt,
26
  msg=str.encode(body.msg),
27
  start_pos_p=[body.start_pos],
28
+ delta=body.delta,
29
  msg_base=body.msg_base,
30
  seed_scheme=body.seed_scheme,
31
  window_length=body.window_length,
 
64
  "encrypt": {
65
  "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
66
  "start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
67
+ "delta": GlobalConfig.get("encrypt.default", "delta"),
68
  "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
69
  "seed_scheme": GlobalConfig.get(
70
  "encrypt.default", "seed_scheme"
config.ini CHANGED
@@ -22,12 +22,12 @@ opt_13b = str:facebook/opt-13b
22
  [models.params]
23
  dtype = str:bfloat16
24
  load_device = str:cpu
25
- run_device = str:cuda
26
 
27
  [encrypt.default]
28
  gen_model = str:gpt2
29
  start_pos = int:0
30
- gamma = float:10.0
31
  msg_base = int:2
32
  seed_scheme = str:sha_left_hash
33
  window_length = int:1
 
22
  [models.params]
23
  dtype = str:bfloat16
24
  load_device = str:cpu
25
+ run_device = str:cpu
26
 
27
  [encrypt.default]
28
  gen_model = str:gpt2
29
  start_pos = int:0
30
+ delta = float:10.0
31
  msg_base = int:2
32
  seed_scheme = str:sha_left_hash
33
  window_length = int:1
demo.py CHANGED
@@ -12,7 +12,7 @@ def enc_fn(
12
  prompt: str,
13
  msg: str,
14
  start_pos: int,
15
- gamma: float,
16
  msg_base: int,
17
  seed_scheme: str,
18
  window_length: int,
@@ -28,7 +28,7 @@ def enc_fn(
28
  prompt=prompt,
29
  msg=str.encode(msg),
30
  start_pos_p=[start_pos],
31
- gamma=gamma,
32
  msg_base=msg_base,
33
  seed_scheme=seed_scheme,
34
  window_length=window_length,
@@ -98,7 +98,7 @@ if __name__ == "__main__":
98
  gr.Textbox(),
99
  gr.Textbox(),
100
  gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))),
101
- gr.Number(float(GlobalConfig.get("encrypt.default", "gamma"))),
102
  gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))),
103
  gr.Dropdown(
104
  value=GlobalConfig.get("encrypt.default", "seed_scheme"),
 
12
  prompt: str,
13
  msg: str,
14
  start_pos: int,
15
+ delta: float,
16
  msg_base: int,
17
  seed_scheme: str,
18
  window_length: int,
 
28
  prompt=prompt,
29
  msg=str.encode(msg),
30
  start_pos_p=[start_pos],
31
+ delta=delta,
32
  msg_base=msg_base,
33
  seed_scheme=seed_scheme,
34
  window_length=window_length,
 
98
  gr.Textbox(),
99
  gr.Textbox(),
100
  gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))),
101
+ gr.Number(float(GlobalConfig.get("encrypt.default", "delta"))),
102
  gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))),
103
  gr.Dropdown(
104
  value=GlobalConfig.get("encrypt.default", "seed_scheme"),
main.py CHANGED
@@ -25,9 +25,9 @@ def create_args():
25
  )
26
  # Stenography params
27
  parser.add_argument(
28
- "--gamma",
29
  type=float,
30
- default=GlobalConfig.get("encrypt.default", "gamma"),
31
  help="Bias added to scores of tokens in valid list",
32
  )
33
  parser.add_argument(
@@ -162,7 +162,7 @@ def main(args):
162
  print("- " * (os.get_terminal_size().columns // 2))
163
  print(args.msg)
164
  print("- " * (os.get_terminal_size().columns // 2))
165
- print(f" Gamma: {args.gamma}")
166
  print(f" Message Base: {args.msg_base}")
167
  print(f" Seed Scheme: {args.seed_scheme}")
168
  print(f" Window Length: {args.window_length}")
@@ -177,7 +177,7 @@ def main(args):
177
  prompt=args.prompt,
178
  msg=args.msg,
179
  start_pos_p=args.start_pos,
180
- gamma=args.gamma,
181
  msg_base=args.msg_base,
182
  seed_scheme=args.seed_scheme,
183
  window_length=args.window_length,
 
25
  )
26
  # Stenography params
27
  parser.add_argument(
28
+ "--delta",
29
  type=float,
30
+ default=GlobalConfig.get("encrypt.default", "delta"),
31
  help="Bias added to scores of tokens in valid list",
32
  )
33
  parser.add_argument(
 
162
  print("- " * (os.get_terminal_size().columns // 2))
163
  print(args.msg)
164
  print("- " * (os.get_terminal_size().columns // 2))
165
+ print(f" delta: {args.delta}")
166
  print(f" Message Base: {args.msg_base}")
167
  print(f" Seed Scheme: {args.seed_scheme}")
168
  print(f" Window Length: {args.window_length}")
 
177
  prompt=args.prompt,
178
  msg=args.msg,
179
  start_pos_p=args.start_pos,
180
+ delta=args.delta,
181
  msg_base=args.msg_base,
182
  seed_scheme=args.seed_scheme,
183
  window_length=args.window_length,
processors.py CHANGED
@@ -104,7 +104,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
104
  self,
105
  prompt_ids: torch.Tensor,
106
  msg: bytes,
107
- gamma: float,
108
  tokenizer,
109
  start_pos: int = 0,
110
  *args,
@@ -113,7 +113,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
113
  """
114
  Args:
115
  msg: message to hide in the text.
116
- gamma: bias add to scores of token in valid list.
117
  """
118
  super().__init__(*args, **kwargs)
119
  if prompt_ids.size(0) != 1:
@@ -126,7 +126,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
126
 
127
  self.raw_msg = msg
128
  self.msg = bytes_to_base(msg, self.msg_base)
129
- self.gamma = gamma
130
  self.tokenizer = tokenizer
131
  special_tokens = [
132
  tokenizer.bos_token_id,
@@ -158,13 +158,13 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
158
  self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
159
  ):
160
  """
161
- Add the bias (gamma) to the valid list tokens
162
  """
163
  ids = torch.cat(
164
  [self._get_valid_list_ids(input_ids, value), self.special_tokens]
165
  )
166
 
167
- scores[ids] = scores[ids] + self.gamma
168
  return scores
169
 
170
  def get_message_len(self):
 
104
  self,
105
  prompt_ids: torch.Tensor,
106
  msg: bytes,
107
+ delta: float,
108
  tokenizer,
109
  start_pos: int = 0,
110
  *args,
 
113
  """
114
  Args:
115
  msg: message to hide in the text.
116
+ delta: bias add to scores of token in valid list.
117
  """
118
  super().__init__(*args, **kwargs)
119
  if prompt_ids.size(0) != 1:
 
126
 
127
  self.raw_msg = msg
128
  self.msg = bytes_to_base(msg, self.msg_base)
129
+ self.delta = delta
130
  self.tokenizer = tokenizer
131
  special_tokens = [
132
  tokenizer.bos_token_id,
 
158
  self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
159
  ):
160
  """
161
+ Add the bias (delta) to the valid list tokens
162
  """
163
  ids = torch.cat(
164
  [self._get_valid_list_ids(input_ids, value), self.special_tokens]
165
  )
166
 
167
+ scores[ids] = scores[ids] + self.delta
168
  return scores
169
 
170
  def get_message_len(self):
schemes.py CHANGED
@@ -1,34 +1,88 @@
1
- from pydantic import BaseModel
2
  from global_config import GlobalConfig
 
 
 
3
 
4
 
5
  class EncryptionBody(BaseModel):
6
- prompt: str
7
- msg: str
8
- gen_model: str = GlobalConfig.get("encrypt.default", "gen_model")
9
- start_pos: int = GlobalConfig.get("encrypt.default", "start_pos")
 
 
 
 
 
 
 
10
 
11
- gamma: float = GlobalConfig.get("encrypt.default", "gamma")
12
- msg_base: int = GlobalConfig.get("encrypt.default", "msg_base")
 
 
 
 
 
 
 
 
13
 
14
- seed_scheme: str = GlobalConfig.get("encrypt.default", "seed_scheme")
15
- window_length: int = GlobalConfig.get(
16
- "encrypt.default", "window_length"
17
  )
18
- private_key: int = GlobalConfig.get("encrypt.default", "private_key")
19
- max_new_tokens_ratio: float = GlobalConfig.get(
20
- "encrypt.default", "max_new_tokens_ratio"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
- num_beams: int = GlobalConfig.get("encrypt.default", "num_beams")
23
- repetition_penalty: float = GlobalConfig.get('encrypt.default', "repetition_penalty")
24
 
25
- class DecryptionBody(BaseModel):
26
- text: str
27
- gen_model: str = GlobalConfig.get("decrypt.default", "gen_model")
28
- msg_base: int = GlobalConfig.get("decrypt.default", "msg_base")
29
 
30
- seed_scheme: str = GlobalConfig.get("decrypt.default", "seed_scheme")
31
- window_length: int = GlobalConfig.get(
32
- "decrypt.default", "window_length"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
- private_key: int = GlobalConfig.get("decrypt.default", "private_key")
 
1
+ from pydantic import BaseModel, Field
2
  from global_config import GlobalConfig
3
+ from model_factory import ModelFactory
4
+ from seed_scheme_factory import SeedSchemeFactory
5
+ from typing import Literal
6
 
7
 
8
  class EncryptionBody(BaseModel):
9
+ prompt: str = Field(title="Prompt used to generate text")
10
+ msg: str = Field(title="Message wanted to hide")
11
+ gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
12
+ default=GlobalConfig.get("encrypt.default", "gen_model"),
13
+ title="LLM used to generate text",
14
+ )
15
+ start_pos: int = Field(
16
+ default=GlobalConfig.get("encrypt.default", "start_pos"),
17
+ title="Start position to encrypt the message",
18
+ ge=0,
19
+ )
20
 
21
+ delta: float = Field(
22
+ default=GlobalConfig.get("encrypt.default", "delta"),
23
+ title="Hardness parameters",
24
+ gt=0,
25
+ )
26
+ msg_base: int = Field(
27
+ default=GlobalConfig.get("encrypt.default", "msg_base"),
28
+ title="Base of message used in base-encoding",
29
+ ge=2,
30
+ )
31
 
32
+ seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
33
+ default=GlobalConfig.get("encrypt.default", "seed_scheme"),
34
+ title="Scheme used to compute seed for PRF",
35
  )
36
+ window_length: int = Field(
37
+ default=GlobalConfig.get("encrypt.default", "window_length"),
38
+ title="Window length (context size) used to compute the seed for PRF",
39
+ ge=1,
40
+ )
41
+ private_key: int = Field(
42
+ default=GlobalConfig.get("encrypt.default", "private_key"),
43
+ title="Private key used to compute the seed for PRF",
44
+ ge=0,
45
+ )
46
+ max_new_tokens_ratio: float = Field(
47
+ default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
48
+ title="Max length of generated text compared to the minimum length required to hide the message",
49
+ ge=1,
50
+ )
51
+ num_beams: int = Field(
52
+ default=GlobalConfig.get("encrypt.default", "num_beams"),
53
+ title="Number of beams used in beam search",
54
+ ge=1,
55
+ )
56
+
57
+ repetition_penalty: float = Field(
58
+ default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
59
+ title="Penalty used to avoid repetition when sampling tokens",
60
+ ge=1,
61
  )
 
 
62
 
 
 
 
 
63
 
64
+ class DecryptionBody(BaseModel):
65
+ text: str = Field(title="Text containing the message")
66
+ gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
67
+ default=GlobalConfig.get("decrypt.default", "gen_model"),
68
+ title="LLM used to generate text",
69
+ )
70
+ msg_base: int = Field(
71
+ default=GlobalConfig.get("decrypt.default", "msg_base"),
72
+ title="Base of message used in base-encoding",
73
+ ge=2,
74
+ )
75
+ seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
76
+ default=GlobalConfig.get("decrypt.default", "seed_scheme"),
77
+ title="Scheme used to compute seed for PRF",
78
+ )
79
+ window_length: int = Field(
80
+ default=GlobalConfig.get("decrypt.default", "window_length"),
81
+ title="Window length (context size) used to compute the seed for PRF",
82
+ ge=1,
83
+ )
84
+ private_key: int = Field(
85
+ default=GlobalConfig.get("decrypt.default", "private_key"),
86
+ title="Private key used to compute the seed for PRF",
87
+ ge=0,
88
  )
 
stegno.py CHANGED
@@ -12,7 +12,7 @@ def generate(
12
  prompt: str,
13
  msg: bytes,
14
  start_pos_p: list[int],
15
- gamma: float,
16
  msg_base: int,
17
  seed_scheme: str,
18
  window_length: int = 1,
@@ -30,7 +30,7 @@ def generate(
30
  model: generative model to use.
31
  prompt: input prompt.
32
  msg: message to hide in the text.
33
- gamma: bias add to scores of token in valid list.
34
  msg_base: base of the message.
35
  seed_scheme: scheme used to compute the seed.
36
  window_length: length of window to compute the seed.
@@ -52,7 +52,7 @@ def generate(
52
  prompt_ids=tokenized_input.input_ids,
53
  msg=msg,
54
  start_pos=start_pos,
55
- gamma=gamma,
56
  msg_base=msg_base,
57
  vocab=list(tokenizer.get_vocab().values()),
58
  tokenizer=tokenizer,
@@ -107,7 +107,7 @@ def decrypt(
107
  tokenizer: tokenizer to use.
108
  text: text to decode.
109
  msg_base: base of the message.
110
- gamma: bias added to scores of valid list.
111
  seed_scheme: scheme used to compute the seed.
112
  window_length: length of window to compute the seed.
113
  salt_key: salt to add to the seed.
 
12
  prompt: str,
13
  msg: bytes,
14
  start_pos_p: list[int],
15
+ delta: float,
16
  msg_base: int,
17
  seed_scheme: str,
18
  window_length: int = 1,
 
30
  model: generative model to use.
31
  prompt: input prompt.
32
  msg: message to hide in the text.
33
+ delta: bias add to scores of token in valid list.
34
  msg_base: base of the message.
35
  seed_scheme: scheme used to compute the seed.
36
  window_length: length of window to compute the seed.
 
52
  prompt_ids=tokenized_input.input_ids,
53
  msg=msg,
54
  start_pos=start_pos,
55
+ delta=delta,
56
  msg_base=msg_base,
57
  vocab=list(tokenizer.get_vocab().values()),
58
  tokenizer=tokenizer,
 
107
  tokenizer: tokenizer to use.
108
  text: text to decode.
109
  msg_base: base of the message.
110
+ delta: bias added to scores of valid list.
111
  seed_scheme: scheme used to compute the seed.
112
  window_length: length of window to compute the seed.
113
  salt_key: salt to add to the seed.