tnk2908 committed
Commit 1f125f1
Parent(s): 72a1159

Integrate global config
Files changed (8)
  1. api.py +88 -45
  2. demo.py +60 -39
  3. main.py +10 -8
  4. processors.py +17 -2
  5. requirements.txt +1 -0
  6. schemes.py +33 -0
  7. stegno.py +1 -0
  8. utils.py +5 -0
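
The new GlobalConfig module itself is not among the files changed, yet every file below reads settings through GlobalConfig.get(section, key). As a reading aid, here is a minimal sketch of the interface the diff assumes; the YAML backing store and the config.yaml file name are guesses, not part of the commit:

# Hypothetical sketch of the GlobalConfig interface used throughout this
# commit. The real global_config.py is not in this diff; the YAML backing
# file and its name are assumptions.
import yaml


class GlobalConfig:
    _config: dict = {}

    @classmethod
    def load(cls, path: str = "config.yaml"):
        with open(path) as f:
            cls._config = yaml.safe_load(f)

    @classmethod
    def get(cls, section: str, key: str):
        # e.g. GlobalConfig.get("encrypt.default", "msg_base") -> "2"
        # Returns None when a key is missing, matching the port handling
        # in api.py below.
        return cls._config.get(section, {}).get(key)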
api.py CHANGED
@@ -1,66 +1,109 @@
+import base64
+
 import torch
 from fastapi import FastAPI
+import uvicorn
 
 from stegno import generate, decrypt
 from utils import load_model
+from seed_scheme_factory import SeedSchemeFactory
+from model_factory import ModelFactory
+from global_config import GlobalConfig
+from schemes import DecryptionBody, EncryptionBody
 
 app = FastAPI()
 
 
-@app.get("/encrypt")
-async def encrypt(
-    prompt: str,
-    msg: str,
-    gen_model: str = "openai-community/gpt2",
-    device: str = "cpu",
-    start_pos: int = 0,
-    gamma: float = 2.0,
-    msg_base: int = 2,
-    seed_scheme: str = "dummy_hash",
-    window_length: int = 1,
-    private_key: int = 0,
-    max_new_tokens_ratio: float = 2,
-    num_beams: int = 4,
+@app.post("/encrypt")
+async def encrypt_api(
+    body: EncryptionBody,
 ):
-    model, tokenizer = load_model(gen_model, torch.device(device))
+    model, tokenizer = ModelFactory.load_model(body.gen_model)
     text, msg_rate = generate(
         tokenizer=tokenizer,
         model=model,
-        prompt=prompt,
-        msg=str.encode(msg),
-        start_pos_p=[start_pos],
-        gamma=gamma,
-        msg_base=msg_base,
-        seed_scheme=seed_scheme,
-        window_length=window_length,
-        private_key=private_key,
-        max_new_tokens_ratio=max_new_tokens_ratio,
-        num_beams=num_beams,
+        prompt=body.prompt,
+        msg=str.encode(body.msg),
+        start_pos_p=[body.start_pos],
+        gamma=body.gamma,
+        msg_base=body.msg_base,
+        seed_scheme=body.seed_scheme,
+        window_length=body.window_length,
+        private_key=body.private_key,
+        max_new_tokens_ratio=body.max_new_tokens_ratio,
+        num_beams=body.num_beams,
     )
     return {"text": text, "msg_rate": msg_rate}
 
 
-@app.get("/decrypt")
-async def dec(
-    text: str,
-    gen_model: str = "openai-community/gpt2",
-    device: str = "cpu",
-    msg_base: int = 2,
-    seed_scheme: str = "dummy_hash",
-    window_length: int = 1,
-    private_key: int = 0,
-):
-    model, tokenizer = load_model(gen_model, torch.device(device))
+@app.post("/decrypt")
+async def decrypt_api(body: DecryptionBody):
+    model, tokenizer = ModelFactory.load_model(body.gen_model)
     msgs = decrypt(
         tokenizer=tokenizer,
         device=model.device,
-        text=text,
-        msg_base=msg_base,
-        seed_scheme=seed_scheme,
-        window_length=window_length,
-        private_key=private_key,
+        text=body.text,
+        msg_base=body.msg_base,
+        seed_scheme=body.seed_scheme,
+        window_length=body.window_length,
+        private_key=body.private_key,
     )
-    msg_text = ""
-    for i, msg in enumerate(msgs):
-        msg_text += f"Shift {i}: {msg}\n\n"
-    return {"msg": msg_text}
+    msg_b64 = {}
+    for i, s_msg in enumerate(msgs):
+        msg_b64[i] = []
+        for msg in s_msg:
+            msg_b64[i].append(base64.b64encode(msg))
+    return msg_b64
+
+
+@app.get("/configs")
+async def default_config():
+    configs = {
+        "default": {
+            "encrypt": {
+                "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
+                "start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
+                "gamma": GlobalConfig.get("encrypt.default", "gamma"),
+                "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
+                "seed_scheme": GlobalConfig.get(
+                    "encrypt.default", "seed_scheme"
+                ),
+                "window_length": GlobalConfig.get(
+                    "encrypt.default", "window_length"
+                ),
+                "private_key": GlobalConfig.get(
+                    "encrypt.default", "private_key"
+                ),
+                "max_new_tokens_ratio": GlobalConfig.get(
+                    "encrypt.default", "max_new_tokens_ratio"
+                ),
+                "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
+            },
+            "decrypt": {
+                "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
+                "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
+                "seed_scheme": GlobalConfig.get(
+                    "encrypt.default", "seed_scheme"
+                ),
+                "window_length": GlobalConfig.get(
+                    "encrypt.default", "window_length"
+                ),
+                "private_key": GlobalConfig.get(
+                    "encrypt.default", "private_key"
+                ),
+            },
+        },
+        "seed_schemes": SeedSchemeFactory.get_schemes_name(),
+        "models": ModelFactory.get_models_names(),
+    }
+
+    return configs
+
+
+if __name__ == "__main__":
+    port = GlobalConfig.get("server", "port")
+    if port is None:
+        port = 8000
+    else:
+        port = int(port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
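
Both endpoints now take a JSON body instead of query parameters (POST rather than GET), with unspecified fields falling back to the config-driven defaults in schemes.py. (The "decrypt" block of /configs reads from the "encrypt.default" section rather than "decrypt.default", which looks like an oversight in the commit.) A usage sketch, assuming the server is running locally on the default port 8000; prompt and message values are illustrative:

# Illustrative client calls; endpoint paths and field names come from the
# diff above, while the host, port, and payload values are assumptions.
import requests

enc = requests.post(
    "http://localhost:8000/encrypt",
    json={"prompt": "Once upon a time", "msg": "secret"},
).json()
print(enc["text"], enc["msg_rate"])

dec = requests.post(
    "http://localhost:8000/decrypt",
    json={"text": enc["text"]},
).json()
print(dec)  # base64-encoded message candidates, keyed by shift index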
demo.py CHANGED
@@ -1,26 +1,26 @@
 import torch
 import gradio as gr
 
-from utils import load_model
+from model_factory import ModelFactory
 from stegno import generate, decrypt
 from seed_scheme_factory import SeedSchemeFactory
+from global_config import GlobalConfig
 
 
 def enc_fn(
-    gen_model: str = "openai-community/gpt2",
-    device: str = "cpu",
-    prompt: str = "",
-    msg: str = "",
-    start_pos: int = 0,
-    gamma: float = 2.0,
-    msg_base: int = 2,
-    seed_scheme: str = "dummy_hash",
-    window_length: int = 1,
-    private_key: int = 0,
-    max_new_tokens_ratio: float = 2,
-    num_beams: int = 4,
+    gen_model: str,
+    prompt: str,
+    msg: str,
+    start_pos: int,
+    gamma: float,
+    msg_base: int,
+    seed_scheme: str,
+    window_length: int,
+    private_key: int,
+    max_new_tokens_ratio: float,
+    num_beams: int,
 ):
-    model, tokenizer = load_model(gen_model, torch.device(device))
+    model, tokenizer = ModelFactory.load_model(gen_model)
     text, msg_rate = generate(
         tokenizer=tokenizer,
         model=model,
@@ -39,15 +39,14 @@ def enc_fn(
 
 
 def dec_fn(
-    gen_model: str = "openai-community/gpt2",
-    device: str = "cpu",
-    text: str = "",
-    msg_base: int = 2,
-    seed_scheme: str = "dummy_hash",
-    window_length: int = 1,
-    private_key: int = 0,
+    gen_model: str,
+    text: str,
+    msg_base: int,
+    seed_scheme: str,
+    window_length: int,
+    private_key: int,
 ):
-    model, tokenizer = load_model(gen_model, torch.device(device))
+    model, tokenizer = ModelFactory.load_model(gen_model)
     msgs = decrypt(
         tokenizer=tokenizer,
         device=model.device,
@@ -67,34 +66,56 @@ if __name__ == "__main__":
     enc = gr.Interface(
         fn=enc_fn,
         inputs=[
-            gr.Textbox("openai-community/gpt2"),
-            gr.Textbox("cpu"),
+            gr.Dropdown(
+                value=GlobalConfig.get("encrypt.default", "gen_model"),
+                choices=ModelFactory.get_models_names(),
+            ),
             gr.Textbox(),
             gr.Textbox(),
-            gr.Number(),
-            gr.Number(10.0),
-            gr.Number(2),
-            gr.Dropdown(value="dummy_hash", choices=SeedSchemeFactory.get_schemes_name()),
-            gr.Number(1),
-            gr.Number(),
-            gr.Number(2),
-            gr.Number(4),
+            gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))),
+            gr.Number(float(GlobalConfig.get("encrypt.default", "gamma"))),
+            gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))),
+            gr.Dropdown(
+                value=GlobalConfig.get("encrypt.default", "seed_scheme"),
+                choices=SeedSchemeFactory.get_schemes_name(),
+            ),
+            gr.Number(
+                int(GlobalConfig.get("encrypt.default", "window_length"))
+            ),
+            gr.Number(int(GlobalConfig.get("encrypt.default", "private_key"))),
+            gr.Number(
+                float(
+                    GlobalConfig.get("encrypt.default", "max_new_tokens_ratio")
+                )
+            ),
+            gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))),
         ],
         outputs=[
-            gr.Textbox(label="Text containing message", show_label=True, show_copy_button=True),
+            gr.Textbox(
+                label="Text containing message",
+                show_label=True,
+                show_copy_button=True,
+            ),
             gr.Number(label="Percentage of message in text", show_label=True),
         ],
     )
     dec = gr.Interface(
         fn=dec_fn,
         inputs=[
-            gr.Textbox("openai-community/gpt2"),
-            gr.Textbox("cpu"),
+            gr.Dropdown(
+                value=GlobalConfig.get("decrypt.default", "gen_model"),
+                choices=ModelFactory.get_models_names(),
+            ),
             gr.Textbox(),
-            gr.Number(2),
-            gr.Dropdown(value="dummy_hash", choices=SeedSchemeFactory.get_schemes_name()),
-            gr.Number(1),
-            gr.Number(),
+            gr.Number(int(GlobalConfig.get("decrypt.default", "msg_base"))),
+            gr.Dropdown(
+                value=GlobalConfig.get("decrypt.default", "seed_scheme"),
+                choices=SeedSchemeFactory.get_schemes_name(),
+            ),
+            gr.Number(
+                int(GlobalConfig.get("decrypt.default", "window_length"))
+            ),
+            gr.Number(int(GlobalConfig.get("decrypt.default", "private_key"))),
         ],
         outputs=[
             gr.Textbox(label="Message", show_label=True),
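
Note the casting pattern above: every numeric default is wrapped in int(...) or float(...), which suggests GlobalConfig.get returns raw strings. The tail of demo.py (assembling and launching the app) lies outside this diff; a typical way to serve the two interfaces together would be the standard Gradio tabbed layout, though the tab titles and arrangement here are assumptions:

# Hypothetical launch code; gr.TabbedInterface is standard Gradio, but this
# exact arrangement is not shown in the diff.
app = gr.TabbedInterface([enc, dec], ["Encrypt", "Decrypt"])
app.launch()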
main.py CHANGED
@@ -6,6 +6,8 @@ import torch
 
 from stegno import generate, decrypt
 from utils import load_model
+from global_config import GlobalConfig
+from model_factory import ModelFactory
 
 
 def create_args():
@@ -15,7 +17,7 @@ def create_args():
     parser.add_argument(
         "--gen-model",
         type=str,
-        default="openai-community/gpt2",
+        default=GlobalConfig.get("encrypt.default", "gen_model"),
         help="Generative model (LLM) used to generate text",
     )
     parser.add_argument(
@@ -25,25 +27,25 @@ def create_args():
     parser.add_argument(
         "--gamma",
         type=float,
-        default=2.0,
+        default=GlobalConfig.get("encrypt.default", "gamma"),
         help="Bias added to scores of tokens in valid list",
     )
     parser.add_argument(
         "--msg-base",
         type=int,
-        default=2,
+        default=GlobalConfig.get("encrypt.default", "msg_base"),
         help="Base of message",
     )
     parser.add_argument(
         "--seed-scheme",
         type=str,
-        required=True,
+        default=GlobalConfig.get("encrypt.default", "seed_scheme"),
         help="Scheme used to compute the seed",
     )
     parser.add_argument(
         "--window-length",
         type=int,
-        default=1,
+        default=GlobalConfig.get("encrypt.default", "window_length"),
         help="Length of window to compute the seed",
     )
     parser.add_argument(
@@ -56,13 +58,13 @@ def create_args():
     parser.add_argument(
         "--num-beams",
         type=int,
-        default=4,
+        default=GlobalConfig.get("encrypt.default", "num_beams"),
        help="Number of beams used in beam search",
     )
     parser.add_argument(
         "--max-new-tokens-ratio",
         type=float,
-        default=2,
+        default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
         help="Ratio of max new tokens to minimum tokens required to hide message",
     )
     # Input
@@ -89,7 +91,7 @@ def create_args():
         "--start-pos",
         type=int,
         nargs="+",
-        default=[0],
+        default=[GlobalConfig.get("encrypt.default", "start_pos")],
         help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
     )
     # Mode
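
One consequence worth spelling out: --seed-scheme is no longer required, and every hyperparameter now falls back to the "encrypt.default" section of the global config. Even if GlobalConfig.get returns strings, the defaults still end up with the right types, because argparse applies the declared type converter to string defaults. A self-contained illustration (the flag name is reused from above; the "1" default stands in for a config value):

# argparse converts string defaults using the declared type, so a string
# pulled from a config file behaves like a typed command-line value.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--window-length", type=int, default="1")
args = parser.parse_args([])
assert args.window_length == 1  # converted by type=int, not left as "1"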
processors.py CHANGED
@@ -52,7 +52,9 @@ class BaseProcessor(object):
         self.rng = torch.Generator(device="cpu")
 
         # Compute the ranges of each value in base
-        self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64).to(self.device)
+        self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64).to(
+            self.device
+        )
         chunk_size = self.vocab_size / self.msg_base
         r = self.vocab_size % self.msg_base
         self.ranges[1:] = chunk_size
@@ -103,6 +105,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         prompt_ids: torch.Tensor,
         msg: bytes,
         gamma: float,
+        tokenizer,
         start_pos: int = 0,
         *args,
         **kwargs,
@@ -124,6 +127,15 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         self.raw_msg = msg
         self.msg = bytes_to_base(msg, self.msg_base)
         self.gamma = gamma
+        special_tokens = [
+            tokenizer.bos_token_id,
+            tokenizer.eos_token_id,
+            tokenizer.sep_token_id,
+            tokenizer.pad_token_id,
+            tokenizer.cls_token_id,
+        ]
+        special_tokens = [x for x in special_tokens if x is not None]
+        self.special_tokens = torch.tensor(special_tokens, device=self.device)
 
     def __call__(
         self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
@@ -147,7 +159,10 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         """
         Add the bias (gamma) to the valid list tokens
         """
-        ids = self._get_valid_list_ids(input_ids, value)
+        ids = torch.cat(
+            [self._get_valid_list_ids(input_ids, value), self.special_tokens]
+        )
+
         scores[ids] = scores[ids] + self.gamma
         return scores
168
 
requirements.txt CHANGED
@@ -7,3 +7,4 @@ torch==2.3.0
 cryptography==42.0.8
 fastapi
 gradio
+uvicorn
schemes.py ADDED
@@ -0,0 +1,33 @@
+from pydantic import BaseModel
+from global_config import GlobalConfig
+
+
+class EncryptionBody(BaseModel):
+    prompt: str
+    msg: str
+    gen_model: str = GlobalConfig.get("encrypt.default", "gen_model")
+    start_pos: int = GlobalConfig.get("encrypt.default", "start_pos")
+
+    gamma: float = GlobalConfig.get("encrypt.default", "gamma")
+    msg_base: int = GlobalConfig.get("encrypt.default", "msg_base")
+
+    seed_scheme: str = GlobalConfig.get("encrypt.default", "seed_scheme")
+    window_length: int = GlobalConfig.get(
+        "encrypt.default", "window_length"
+    )
+    private_key: int = GlobalConfig.get("encrypt.default", "private_key")
+    max_new_tokens_ratio: float = GlobalConfig.get(
+        "encrypt.default", "max_new_tokens_ratio"
+    )
+    num_beams: int = GlobalConfig.get("encrypt.default", "num_beams")
+
+class DecryptionBody(BaseModel):
+    text: str
+    gen_model: str = GlobalConfig.get("decrypt.default", "gen_model")
+    msg_base: int = GlobalConfig.get("decrypt.default", "msg_base")
+
+    seed_scheme: str = GlobalConfig.get("decrypt.default", "seed_scheme")
+    window_length: int = GlobalConfig.get(
+        "decrypt.default", "window_length"
+    )
+    private_key: int = GlobalConfig.get("decrypt.default", "private_key")
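
These pydantic models supply the API's defaults: GlobalConfig.get is evaluated once, at import time, so the request bodies capture whatever the config file held when the server started. A construction sketch (field values are illustrative):

# Illustrative: only the required fields need to be supplied; the rest fall
# back to the config-derived defaults captured at import time.
body = EncryptionBody(prompt="Once upon a time", msg="secret")
print(body.msg_base, body.seed_scheme)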
stegno.py CHANGED
@@ -54,6 +54,7 @@ def generate(
         gamma=gamma,
         msg_base=msg_base,
         vocab=list(tokenizer.get_vocab().values()),
+        tokenizer=tokenizer,
         device=model.device,
         seed_scheme=seed_scheme,
         window_length=window_length,
utils.py CHANGED
@@ -50,3 +50,8 @@ def load_model(name: str, device: torch.device):
     tokenizer = AutoTokenizer.from_pretrained(name)
 
     return model, tokenizer
+
+def static_init(cls):
+    if getattr(cls, "__static_init__", None):
+        cls.__static_init__()
+    return cls
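
The new static_init helper is a class decorator that fires a one-time hook right after a class object is created; presumably this is how factories such as ModelFactory and SeedSchemeFactory populate their registries at import time. A minimal usage sketch (the class and its registry are hypothetical):

# Hypothetical usage of the static_init decorator defined above.
@static_init
class Registry:
    entries: dict = {}

    @classmethod
    def __static_init__(cls):
        # Runs once, immediately after the class body is evaluated.
        cls.entries["default"] = "initialized"

print(Registry.entries)  # {'default': 'initialized'}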