Spaces:
Running
Running
File size: 5,365 Bytes
7235a64 1f125f1 7235a64 1f125f1 7235a64 1f125f1 cc8b2eb 1f125f1 ee83d59 7235a64 1f125f1 ee83d59 7235a64 cc8b2eb 7235a64 ee83d59 7235a64 ee83d59 7235a64 1f125f1 7235a64 1f125f1 7235a64 1f125f1 7235a64 1f125f1 cc8b2eb 1f125f1 ee83d59 7235a64 1f125f1 ee83d59 7235a64 1f125f1 7235a64 1f125f1 7235a64 8d94857 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
import gradio as gr
from model_factory import ModelFactory
from stegno import generate, decrypt
from seed_scheme_factory import SeedSchemeFactory
from global_config import GlobalConfig
def _display_text(token_text):
    """Return *token_text* escaped for on-screen display.

    ``repr`` escapes control characters such as newlines and tabs; the slice
    drops the surrounding quote characters that ``repr`` adds.
    """
    return repr(token_text)[1:-1]


def _match_status(expected, encoded):
    """Classify one position for a HighlightedText output.

    Returns ``None`` when *expected* is ``-1`` (no value to compare at this
    position), ``"correct"`` when *encoded* equals *expected*, and
    ``"wrong"`` otherwise.
    """
    if expected == -1:
        return None
    return "correct" if expected == encoded else "wrong"


def enc_fn(
    gen_model: str,
    prompt: str,
    msg: str,
    start_pos: int,
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int,
    private_key: int,
    max_new_tokens_ratio: float,
    num_beams: int,
    repetition_penalty: float,
):
    """Generate text from *prompt* that carries *msg* and build display data.

    The model/tokenizer named by *gen_model* are loaded through
    ``ModelFactory`` and passed with every other parameter to
    ``stegno.generate``; *msg* is UTF-8 encoded first and *start_pos* is
    wrapped in a one-element list for ``start_pos_p``.

    Returns:
        A 4-tuple ``(text, highlight_base, highlight_byte, percentage)``:

        * ``text`` -- the generated text returned by ``generate``.
        * ``highlight_base`` -- one ``(token_text, status)`` tuple per token,
          where status compares ``token["base_msg"]`` with
          ``token["base_enc"]``.
        * ``highlight_byte`` -- one ``[text, status]`` list per run of
          consecutive tokens sharing the same ``byte_id``; the run's text is
          the concatenation of its tokens, and its status compares the first
          token's ``byte_msg`` with its ``byte_enc``.
        * ``percentage`` -- ``msg_rate * 100`` rounded to 2 decimals.
    """
    model, tokenizer = ModelFactory.load_model(gen_model)
    text, msg_rate, tokens_info = generate(
        tokenizer=tokenizer,
        model=model,
        prompt=prompt,
        msg=str.encode(msg),
        start_pos_p=[start_pos],
        delta=delta,
        msg_base=msg_base,
        seed_scheme=seed_scheme,
        window_length=window_length,
        private_key=private_key,
        max_new_tokens_ratio=max_new_tokens_ratio,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
    )

    # Per-token view of the base-level comparison.
    highlight_base = [
        (
            _display_text(token["token"]),
            _match_status(token["base_msg"], token["base_enc"]),
        )
        for token in tokens_info
    ]

    # Per-byte view: merge consecutive tokens that share a byte_id. Entries
    # are lists (not tuples) so a run's text can be extended in place.
    highlight_byte = []
    prev_token = None
    for token in tokens_info:
        piece = _display_text(token["token"])
        if prev_token is not None and prev_token["byte_id"] == token["byte_id"]:
            highlight_byte[-1][0] += piece
        else:
            highlight_byte.append(
                [piece, _match_status(token["byte_msg"], token["byte_enc"])]
            )
        prev_token = token

    return text, highlight_base, highlight_byte, round(msg_rate * 100, 2)
def dec_fn(
    gen_model: str,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int,
    private_key: int,
):
    """Recover the hidden message(s) from *text* and format them for display.

    The tokenizer (and the model, for its ``device``) named by *gen_model*
    is loaded through ``ModelFactory``, then ``stegno.decrypt`` is run with
    the given scheme parameters. Each element it returns is rendered as
    ``"Shift {index}: {value}"`` followed by a blank line; an empty result
    yields an empty string.
    """
    model, tokenizer = ModelFactory.load_model(gen_model)
    decoded = decrypt(
        tokenizer=tokenizer,
        device=model.device,
        text=text,
        msg_base=msg_base,
        seed_scheme=seed_scheme,
        window_length=window_length,
        private_key=private_key,
    )
    return "".join(
        f"Shift {shift}: {candidate}\n\n"
        for shift, candidate in enumerate(decoded)
    )
if __name__ == "__main__":
    # Config sections that supply the default value of each input widget.
    ENC_SECTION = "encrypt.default"
    DEC_SECTION = "decrypt.default"

    def enc_default(key):
        """Return the raw encryption default stored under *key*."""
        return GlobalConfig.get(ENC_SECTION, key)

    def dec_default(key):
        """Return the raw decryption default stored under *key*."""
        return GlobalConfig.get(DEC_SECTION, key)

    def highlighted_output(label):
        """Build a HighlightedText output colouring "correct"/"wrong" spans."""
        return gr.HighlightedText(
            label=label,
            combine_adjacent=False,
            show_legend=True,
            color_map={"correct": "green", "wrong": "red"},
        )

    # Inputs are positional: their order must match enc_fn's parameters.
    # NOTE(review): the integer gr.Number fields do not set precision=0, so
    # depending on the Gradio version they may reach enc_fn/dec_fn as floats
    # -- confirm stegno.generate/decrypt tolerate that.
    enc = gr.Interface(
        fn=enc_fn,
        inputs=[
            gr.Dropdown(
                value=enc_default("gen_model"),
                choices=ModelFactory.get_models_names(),
            ),
            gr.Textbox(),  # prompt
            gr.Textbox(),  # msg
            gr.Number(int(enc_default("start_pos"))),
            gr.Number(float(enc_default("delta"))),
            gr.Number(int(enc_default("msg_base"))),
            gr.Dropdown(
                value=enc_default("seed_scheme"),
                choices=SeedSchemeFactory.get_schemes_name(),
            ),
            gr.Number(int(enc_default("window_length"))),
            gr.Number(int(enc_default("private_key"))),
            gr.Number(float(enc_default("max_new_tokens_ratio"))),
            gr.Number(int(enc_default("num_beams"))),
            gr.Number(float(enc_default("repetition_penalty"))),
        ],
        outputs=[
            gr.Textbox(
                label="Text containing message",
                show_label=True,
                show_copy_button=True,
            ),
            highlighted_output("Text containing message (Base highlighted)"),
            highlighted_output("Text containing message (Byte highlighted)"),
            gr.Number(label="Percentage of message in text", show_label=True),
        ],
    )

    # Inputs are positional: their order must match dec_fn's parameters.
    dec = gr.Interface(
        fn=dec_fn,
        inputs=[
            gr.Dropdown(
                value=dec_default("gen_model"),
                choices=ModelFactory.get_models_names(),
            ),
            gr.Textbox(),  # text
            gr.Number(int(dec_default("msg_base"))),
            gr.Dropdown(
                value=dec_default("seed_scheme"),
                choices=SeedSchemeFactory.get_schemes_name(),
            ),
            gr.Number(int(dec_default("window_length"))),
            gr.Number(int(dec_default("private_key"))),
        ],
        outputs=[
            gr.Textbox(label="Message", show_label=True),
        ],
    )

    app = gr.TabbedInterface([enc, dec], ["Encryption", "Decryption"])
    app.launch(share=True)
|