tnk2908 commited on
Commit
7235a64
1 Parent(s): 7e8d9b9

Add gradio demo

Browse files
Files changed (3) hide show
  1. demo.py +104 -0
  2. requirements.txt +1 -0
  3. seed_scheme_factory.py +6 -1
demo.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+ from utils import load_model
5
+ from stegno import generate, decrypt
6
+ from seed_scheme_factory import SeedSchemeFactory
7
+
8
+
9
+ def enc_fn(
10
+ gen_model: str = "openai-community/gpt2",
11
+ device: str = "cpu",
12
+ prompt: str = "",
13
+ msg: str = "",
14
+ start_pos: int = 0,
15
+ gamma: float = 2.0,
16
+ msg_base: int = 2,
17
+ seed_scheme: str = "dummy_hash",
18
+ window_length: int = 1,
19
+ private_key: int = 0,
20
+ max_new_tokens_ratio: float = 2,
21
+ num_beams: int = 4,
22
+ ):
23
+ model, tokenizer = load_model(gen_model, torch.device(device))
24
+ text, msg_rate = generate(
25
+ tokenizer=tokenizer,
26
+ model=model,
27
+ prompt=prompt,
28
+ msg=str.encode(msg),
29
+ start_pos_p=[start_pos],
30
+ gamma=gamma,
31
+ msg_base=msg_base,
32
+ seed_scheme=seed_scheme,
33
+ window_length=window_length,
34
+ private_key=private_key,
35
+ max_new_tokens_ratio=max_new_tokens_ratio,
36
+ num_beams=num_beams,
37
+ )
38
+ return text, msg_rate
39
+
40
+
41
+ def dec_fn(
42
+ gen_model: str = "openai-community/gpt2",
43
+ device: str = "cpu",
44
+ text: str = "",
45
+ msg_base: int = 2,
46
+ seed_scheme: str = "dummy_hash",
47
+ window_length: int = 1,
48
+ private_key: int = 0,
49
+ ):
50
+ model, tokenizer = load_model(gen_model, torch.device(device))
51
+ msgs = decrypt(
52
+ tokenizer=tokenizer,
53
+ device=model.device,
54
+ text=text,
55
+ msg_base=msg_base,
56
+ seed_scheme=seed_scheme,
57
+ window_length=window_length,
58
+ private_key=private_key,
59
+ )
60
+ msg_text = ""
61
+ for i, msg in enumerate(msgs):
62
+ msg_text += f"Shift {i}: {msg}\n\n"
63
+ return msg_text
64
+
65
+
66
+ if __name__ == "__main__":
67
+ enc = gr.Interface(
68
+ fn=enc_fn,
69
+ inputs=[
70
+ gr.Textbox("openai-community/gpt2"),
71
+ gr.Textbox("cpu"),
72
+ gr.Textbox(),
73
+ gr.Textbox(),
74
+ gr.Number(),
75
+ gr.Number(10.0),
76
+ gr.Number(2),
77
+ gr.Dropdown(value="dummy_hash", choices=SeedSchemeFactory.get_schemes_name()),
78
+ gr.Number(1),
79
+ gr.Number(),
80
+ gr.Number(2),
81
+ gr.Number(4),
82
+ ],
83
+ outputs=[
84
+ gr.Textbox(label="Text containing message", show_label=True, show_copy_button=True),
85
+ gr.Number(label="Percentage of message in text", show_label=True),
86
+ ],
87
+ )
88
+ dec = gr.Interface(
89
+ fn=dec_fn,
90
+ inputs=[
91
+ gr.Textbox("openai-community/gpt2"),
92
+ gr.Textbox("cpu"),
93
+ gr.Textbox(),
94
+ gr.Number(2),
95
+ gr.Dropdown(value="dummy_hash", choices=SeedSchemeFactory.get_schemes_name()),
96
+ gr.Number(1),
97
+ gr.Number(),
98
+ ],
99
+ outputs=[
100
+ gr.Textbox(label="Message", show_label=True),
101
+ ],
102
+ )
103
+ app = gr.TabbedInterface([enc, dec], ["Encrytion", "Decryption"])
104
+ app.launch()
requirements.txt CHANGED
@@ -6,3 +6,4 @@ scikit-learn==1.5.0
6
  torch==2.3.0
7
  cryptography==42.0.8
8
  fastapi
 
 
6
  torch==2.3.0
7
  cryptography==42.0.8
8
  fastapi
9
+ gradio
seed_scheme_factory.py CHANGED
@@ -6,6 +6,10 @@ import torch
6
  class SeedSchemeFactory:
7
  registry = {}
8
 
 
 
 
 
9
  @classmethod
10
  def register(cls, name: str):
11
  """
@@ -37,8 +41,9 @@ class SeedSchemeFactory:
37
  return None
38
 
39
 
40
- class SeedScheme():
41
  def __call__(self, input_ids: torch.Tensor) -> int:
42
  return 0
43
 
 
44
  from seed_schemes import *
 
6
  class SeedSchemeFactory:
7
  registry = {}
8
 
9
+ @classmethod
10
+ def get_schemes_name(cls) -> list[str]:
11
+ return list(cls.registry.keys())
12
+
13
  @classmethod
14
  def register(cls, name: str):
15
  """
 
41
  return None
42
 
43
 
44
+ class SeedScheme:
45
  def __call__(self, input_ids: torch.Tensor) -> int:
46
  return 0
47
 
48
+
49
  from seed_schemes import *