Spaces:
Sleeping
Sleeping
Add restrictions and rename parameters to match the ones written in the report
Browse files
- api.py +2 -2
- config.ini +2 -2
- demo.py +3 -3
- main.py +4 -4
- processors.py +5 -5
- schemes.py +77 -23
- stegno.py +4 -4
api.py
CHANGED
@@ -25,7 +25,7 @@ async def encrypt_api(
|
|
25 |
prompt=body.prompt,
|
26 |
msg=str.encode(body.msg),
|
27 |
start_pos_p=[body.start_pos],
|
28 |
-
|
29 |
msg_base=body.msg_base,
|
30 |
seed_scheme=body.seed_scheme,
|
31 |
window_length=body.window_length,
|
@@ -64,7 +64,7 @@ async def default_config():
|
|
64 |
"encrypt": {
|
65 |
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
|
66 |
"start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
|
67 |
-
"
|
68 |
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
|
69 |
"seed_scheme": GlobalConfig.get(
|
70 |
"encrypt.default", "seed_scheme"
|
|
|
25 |
prompt=body.prompt,
|
26 |
msg=str.encode(body.msg),
|
27 |
start_pos_p=[body.start_pos],
|
28 |
+
delta=body.delta,
|
29 |
msg_base=body.msg_base,
|
30 |
seed_scheme=body.seed_scheme,
|
31 |
window_length=body.window_length,
|
|
|
64 |
"encrypt": {
|
65 |
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
|
66 |
"start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
|
67 |
+
"delta": GlobalConfig.get("encrypt.default", "delta"),
|
68 |
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
|
69 |
"seed_scheme": GlobalConfig.get(
|
70 |
"encrypt.default", "seed_scheme"
|
config.ini
CHANGED
@@ -22,12 +22,12 @@ opt_13b = str:facebook/opt-13b
|
|
22 |
[models.params]
|
23 |
dtype = str:bfloat16
|
24 |
load_device = str:cpu
|
25 |
-
run_device = str:
|
26 |
|
27 |
[encrypt.default]
|
28 |
gen_model = str:gpt2
|
29 |
start_pos = int:0
|
30 |
-
|
31 |
msg_base = int:2
|
32 |
seed_scheme = str:sha_left_hash
|
33 |
window_length = int:1
|
|
|
22 |
[models.params]
|
23 |
dtype = str:bfloat16
|
24 |
load_device = str:cpu
|
25 |
+
run_device = str:cpu
|
26 |
|
27 |
[encrypt.default]
|
28 |
gen_model = str:gpt2
|
29 |
start_pos = int:0
|
30 |
+
delta = float:10.0
|
31 |
msg_base = int:2
|
32 |
seed_scheme = str:sha_left_hash
|
33 |
window_length = int:1
|
demo.py
CHANGED
@@ -12,7 +12,7 @@ def enc_fn(
|
|
12 |
prompt: str,
|
13 |
msg: str,
|
14 |
start_pos: int,
|
15 |
-
|
16 |
msg_base: int,
|
17 |
seed_scheme: str,
|
18 |
window_length: int,
|
@@ -28,7 +28,7 @@ def enc_fn(
|
|
28 |
prompt=prompt,
|
29 |
msg=str.encode(msg),
|
30 |
start_pos_p=[start_pos],
|
31 |
-
|
32 |
msg_base=msg_base,
|
33 |
seed_scheme=seed_scheme,
|
34 |
window_length=window_length,
|
@@ -98,7 +98,7 @@ if __name__ == "__main__":
|
|
98 |
gr.Textbox(),
|
99 |
gr.Textbox(),
|
100 |
gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))),
|
101 |
-
gr.Number(float(GlobalConfig.get("encrypt.default", "
|
102 |
gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))),
|
103 |
gr.Dropdown(
|
104 |
value=GlobalConfig.get("encrypt.default", "seed_scheme"),
|
|
|
12 |
prompt: str,
|
13 |
msg: str,
|
14 |
start_pos: int,
|
15 |
+
delta: float,
|
16 |
msg_base: int,
|
17 |
seed_scheme: str,
|
18 |
window_length: int,
|
|
|
28 |
prompt=prompt,
|
29 |
msg=str.encode(msg),
|
30 |
start_pos_p=[start_pos],
|
31 |
+
delta=delta,
|
32 |
msg_base=msg_base,
|
33 |
seed_scheme=seed_scheme,
|
34 |
window_length=window_length,
|
|
|
98 |
gr.Textbox(),
|
99 |
gr.Textbox(),
|
100 |
gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))),
|
101 |
+
gr.Number(float(GlobalConfig.get("encrypt.default", "delta"))),
|
102 |
gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))),
|
103 |
gr.Dropdown(
|
104 |
value=GlobalConfig.get("encrypt.default", "seed_scheme"),
|
main.py
CHANGED
@@ -25,9 +25,9 @@ def create_args():
|
|
25 |
)
|
26 |
# Stenography params
|
27 |
parser.add_argument(
|
28 |
-
"--
|
29 |
type=float,
|
30 |
-
default=GlobalConfig.get("encrypt.default", "
|
31 |
help="Bias added to scores of tokens in valid list",
|
32 |
)
|
33 |
parser.add_argument(
|
@@ -162,7 +162,7 @@ def main(args):
|
|
162 |
print("- " * (os.get_terminal_size().columns // 2))
|
163 |
print(args.msg)
|
164 |
print("- " * (os.get_terminal_size().columns // 2))
|
165 |
-
print(f"
|
166 |
print(f" Message Base: {args.msg_base}")
|
167 |
print(f" Seed Scheme: {args.seed_scheme}")
|
168 |
print(f" Window Length: {args.window_length}")
|
@@ -177,7 +177,7 @@ def main(args):
|
|
177 |
prompt=args.prompt,
|
178 |
msg=args.msg,
|
179 |
start_pos_p=args.start_pos,
|
180 |
-
|
181 |
msg_base=args.msg_base,
|
182 |
seed_scheme=args.seed_scheme,
|
183 |
window_length=args.window_length,
|
|
|
25 |
)
|
26 |
# Stenography params
|
27 |
parser.add_argument(
|
28 |
+
"--delta",
|
29 |
type=float,
|
30 |
+
default=GlobalConfig.get("encrypt.default", "delta"),
|
31 |
help="Bias added to scores of tokens in valid list",
|
32 |
)
|
33 |
parser.add_argument(
|
|
|
162 |
print("- " * (os.get_terminal_size().columns // 2))
|
163 |
print(args.msg)
|
164 |
print("- " * (os.get_terminal_size().columns // 2))
|
165 |
+
print(f" delta: {args.delta}")
|
166 |
print(f" Message Base: {args.msg_base}")
|
167 |
print(f" Seed Scheme: {args.seed_scheme}")
|
168 |
print(f" Window Length: {args.window_length}")
|
|
|
177 |
prompt=args.prompt,
|
178 |
msg=args.msg,
|
179 |
start_pos_p=args.start_pos,
|
180 |
+
delta=args.delta,
|
181 |
msg_base=args.msg_base,
|
182 |
seed_scheme=args.seed_scheme,
|
183 |
window_length=args.window_length,
|
processors.py
CHANGED
@@ -104,7 +104,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
104 |
self,
|
105 |
prompt_ids: torch.Tensor,
|
106 |
msg: bytes,
|
107 |
-
|
108 |
tokenizer,
|
109 |
start_pos: int = 0,
|
110 |
*args,
|
@@ -113,7 +113,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
113 |
"""
|
114 |
Args:
|
115 |
msg: message to hide in the text.
|
116 |
-
|
117 |
"""
|
118 |
super().__init__(*args, **kwargs)
|
119 |
if prompt_ids.size(0) != 1:
|
@@ -126,7 +126,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
126 |
|
127 |
self.raw_msg = msg
|
128 |
self.msg = bytes_to_base(msg, self.msg_base)
|
129 |
-
self.
|
130 |
self.tokenizer = tokenizer
|
131 |
special_tokens = [
|
132 |
tokenizer.bos_token_id,
|
@@ -158,13 +158,13 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
158 |
self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
|
159 |
):
|
160 |
"""
|
161 |
-
Add the bias (
|
162 |
"""
|
163 |
ids = torch.cat(
|
164 |
[self._get_valid_list_ids(input_ids, value), self.special_tokens]
|
165 |
)
|
166 |
|
167 |
-
scores[ids] = scores[ids] + self.
|
168 |
return scores
|
169 |
|
170 |
def get_message_len(self):
|
|
|
104 |
self,
|
105 |
prompt_ids: torch.Tensor,
|
106 |
msg: bytes,
|
107 |
+
delta: float,
|
108 |
tokenizer,
|
109 |
start_pos: int = 0,
|
110 |
*args,
|
|
|
113 |
"""
|
114 |
Args:
|
115 |
msg: message to hide in the text.
|
116 |
+
delta: bias add to scores of token in valid list.
|
117 |
"""
|
118 |
super().__init__(*args, **kwargs)
|
119 |
if prompt_ids.size(0) != 1:
|
|
|
126 |
|
127 |
self.raw_msg = msg
|
128 |
self.msg = bytes_to_base(msg, self.msg_base)
|
129 |
+
self.delta = delta
|
130 |
self.tokenizer = tokenizer
|
131 |
special_tokens = [
|
132 |
tokenizer.bos_token_id,
|
|
|
158 |
self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
|
159 |
):
|
160 |
"""
|
161 |
+
Add the bias (delta) to the valid list tokens
|
162 |
"""
|
163 |
ids = torch.cat(
|
164 |
[self._get_valid_list_ids(input_ids, value), self.special_tokens]
|
165 |
)
|
166 |
|
167 |
+
scores[ids] = scores[ids] + self.delta
|
168 |
return scores
|
169 |
|
170 |
def get_message_len(self):
|
schemes.py
CHANGED
@@ -1,34 +1,88 @@
|
|
1 |
-
from pydantic import BaseModel
|
2 |
from global_config import GlobalConfig
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
class EncryptionBody(BaseModel):
|
6 |
-
prompt: str
|
7 |
-
msg: str
|
8 |
-
gen_model:
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
seed_scheme:
|
15 |
-
|
16 |
-
"
|
17 |
)
|
18 |
-
|
19 |
-
|
20 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
)
|
22 |
-
num_beams: int = GlobalConfig.get("encrypt.default", "num_beams")
|
23 |
-
repetition_penalty: float = GlobalConfig.get('encrypt.default', "repetition_penalty")
|
24 |
|
25 |
-
class DecryptionBody(BaseModel):
|
26 |
-
text: str
|
27 |
-
gen_model: str = GlobalConfig.get("decrypt.default", "gen_model")
|
28 |
-
msg_base: int = GlobalConfig.get("decrypt.default", "msg_base")
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
)
|
34 |
-
private_key: int = GlobalConfig.get("decrypt.default", "private_key")
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
from global_config import GlobalConfig
|
3 |
+
from model_factory import ModelFactory
|
4 |
+
from seed_scheme_factory import SeedSchemeFactory
|
5 |
+
from typing import Literal
|
6 |
|
7 |
|
8 |
class EncryptionBody(BaseModel):
|
9 |
+
prompt: str = Field(title="Prompt used to generate text")
|
10 |
+
msg: str = Field(title="Message wanted to hide")
|
11 |
+
gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
|
12 |
+
default=GlobalConfig.get("encrypt.default", "gen_model"),
|
13 |
+
title="LLM used to generate text",
|
14 |
+
)
|
15 |
+
start_pos: int = Field(
|
16 |
+
default=GlobalConfig.get("encrypt.default", "start_pos"),
|
17 |
+
title="Start position to encrypt the message",
|
18 |
+
ge=0,
|
19 |
+
)
|
20 |
|
21 |
+
delta: float = Field(
|
22 |
+
default=GlobalConfig.get("encrypt.default", "delta"),
|
23 |
+
title="Hardness parameters",
|
24 |
+
gt=0,
|
25 |
+
)
|
26 |
+
msg_base: int = Field(
|
27 |
+
default=GlobalConfig.get("encrypt.default", "msg_base"),
|
28 |
+
title="Base of message used in base-encoding",
|
29 |
+
ge=2,
|
30 |
+
)
|
31 |
|
32 |
+
seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
|
33 |
+
default=GlobalConfig.get("encrypt.default", "seed_scheme"),
|
34 |
+
title="Scheme used to compute seed for PRF",
|
35 |
)
|
36 |
+
window_length: int = Field(
|
37 |
+
default=GlobalConfig.get("encrypt.default", "window_length"),
|
38 |
+
title="Window length (context size) used to compute the seed for PRF",
|
39 |
+
ge=1,
|
40 |
+
)
|
41 |
+
private_key: int = Field(
|
42 |
+
default=GlobalConfig.get("encrypt.default", "private_key"),
|
43 |
+
title="Private key used to compute the seed for PRF",
|
44 |
+
ge=0,
|
45 |
+
)
|
46 |
+
max_new_tokens_ratio: float = Field(
|
47 |
+
default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
|
48 |
+
title="Max length of generated text compared to the minimum length required to hide the message",
|
49 |
+
ge=1,
|
50 |
+
)
|
51 |
+
num_beams: int = Field(
|
52 |
+
default=GlobalConfig.get("encrypt.default", "num_beams"),
|
53 |
+
title="Number of beams used in beam search",
|
54 |
+
ge=1,
|
55 |
+
)
|
56 |
+
|
57 |
+
repetition_penalty: float = Field(
|
58 |
+
default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
|
59 |
+
title="Penalty used to avoid repetition when sampling tokens",
|
60 |
+
ge=1,
|
61 |
)
|
|
|
|
|
62 |
|
|
|
|
|
|
|
|
|
63 |
|
64 |
+
class DecryptionBody(BaseModel):
|
65 |
+
text: str = Field(title="Text containing the message")
|
66 |
+
gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
|
67 |
+
default=GlobalConfig.get("decrypt.default", "gen_model"),
|
68 |
+
title="LLM used to generate text",
|
69 |
+
)
|
70 |
+
msg_base: int = Field(
|
71 |
+
default=GlobalConfig.get("decrypt.default", "msg_base"),
|
72 |
+
title="Base of message used in base-encoding",
|
73 |
+
ge=2,
|
74 |
+
)
|
75 |
+
seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
|
76 |
+
default=GlobalConfig.get("decrypt.default", "seed_scheme"),
|
77 |
+
title="Scheme used to compute seed for PRF",
|
78 |
+
)
|
79 |
+
window_length: int = Field(
|
80 |
+
default=GlobalConfig.get("decrypt.default", "window_length"),
|
81 |
+
title="Window length (context size) used to compute the seed for PRF",
|
82 |
+
ge=1,
|
83 |
+
)
|
84 |
+
private_key: int = Field(
|
85 |
+
default=GlobalConfig.get("decrypt.default", "private_key"),
|
86 |
+
title="Private key used to compute the seed for PRF",
|
87 |
+
ge=0,
|
88 |
)
|
|
stegno.py
CHANGED
@@ -12,7 +12,7 @@ def generate(
|
|
12 |
prompt: str,
|
13 |
msg: bytes,
|
14 |
start_pos_p: list[int],
|
15 |
-
|
16 |
msg_base: int,
|
17 |
seed_scheme: str,
|
18 |
window_length: int = 1,
|
@@ -30,7 +30,7 @@ def generate(
|
|
30 |
model: generative model to use.
|
31 |
prompt: input prompt.
|
32 |
msg: message to hide in the text.
|
33 |
-
|
34 |
msg_base: base of the message.
|
35 |
seed_scheme: scheme used to compute the seed.
|
36 |
window_length: length of window to compute the seed.
|
@@ -52,7 +52,7 @@ def generate(
|
|
52 |
prompt_ids=tokenized_input.input_ids,
|
53 |
msg=msg,
|
54 |
start_pos=start_pos,
|
55 |
-
|
56 |
msg_base=msg_base,
|
57 |
vocab=list(tokenizer.get_vocab().values()),
|
58 |
tokenizer=tokenizer,
|
@@ -107,7 +107,7 @@ def decrypt(
|
|
107 |
tokenizer: tokenizer to use.
|
108 |
text: text to decode.
|
109 |
msg_base: base of the message.
|
110 |
-
|
111 |
seed_scheme: scheme used to compute the seed.
|
112 |
window_length: length of window to compute the seed.
|
113 |
salt_key: salt to add to the seed.
|
|
|
12 |
prompt: str,
|
13 |
msg: bytes,
|
14 |
start_pos_p: list[int],
|
15 |
+
delta: float,
|
16 |
msg_base: int,
|
17 |
seed_scheme: str,
|
18 |
window_length: int = 1,
|
|
|
30 |
model: generative model to use.
|
31 |
prompt: input prompt.
|
32 |
msg: message to hide in the text.
|
33 |
+
delta: bias add to scores of token in valid list.
|
34 |
msg_base: base of the message.
|
35 |
seed_scheme: scheme used to compute the seed.
|
36 |
window_length: length of window to compute the seed.
|
|
|
52 |
prompt_ids=tokenized_input.input_ids,
|
53 |
msg=msg,
|
54 |
start_pos=start_pos,
|
55 |
+
delta=delta,
|
56 |
msg_base=msg_base,
|
57 |
vocab=list(tokenizer.get_vocab().values()),
|
58 |
tokenizer=tokenizer,
|
|
|
107 |
tokenizer: tokenizer to use.
|
108 |
text: text to decode.
|
109 |
msg_base: base of the message.
|
110 |
+
delta: bias added to scores of valid list.
|
111 |
seed_scheme: scheme used to compute the seed.
|
112 |
window_length: length of window to compute the seed.
|
113 |
salt_key: salt to add to the seed.
|