Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from model_factory import ModelFactory | |
from stegno import generate, decrypt | |
from seed_scheme_factory import SeedSchemeFactory | |
from global_config import GlobalConfig | |
def enc_fn( | |
gen_model: str, | |
prompt: str, | |
msg: str, | |
start_pos: int, | |
delta: float, | |
msg_base: int, | |
seed_scheme: str, | |
window_length: int, | |
private_key: int, | |
do_sample: bool, | |
min_new_tokens_ratio: float, | |
max_new_tokens_ratio: float, | |
num_beams: int, | |
repetition_penalty: float, | |
): | |
model, tokenizer = ModelFactory.load_model(gen_model) | |
texts, msgs_rates, tokens_infos = generate( | |
tokenizer=tokenizer, | |
model=model, | |
prompt=prompt, | |
msg=str.encode(msg), | |
start_pos_p=[start_pos], | |
delta=delta, | |
msg_base=msg_base, | |
seed_scheme=seed_scheme, | |
window_length=window_length, | |
private_key=private_key, | |
do_sample=do_sample, | |
min_new_tokens_ratio=min_new_tokens_ratio, | |
max_new_tokens_ratio=max_new_tokens_ratio, | |
num_beams=num_beams, | |
repetition_penalty=repetition_penalty, | |
) | |
highlight_base = [] | |
for token in tokens_infos[0]: | |
stat = None | |
if token["base_msg"] != -1: | |
if token["base_msg"] == token["base_enc"]: | |
stat = "correct" | |
else: | |
stat = "wrong" | |
highlight_base.append((repr(token["token"])[1:-1], stat)) | |
highlight_byte = [] | |
for i, token in enumerate(tokens_infos[0]): | |
if i == 0 or tokens_infos[0][i - 1]["byte_id"] != token["byte_id"]: | |
stat = None | |
if token["byte_msg"] != -1: | |
if token["byte_msg"] == token["byte_enc"]: | |
stat = "correct" | |
else: | |
stat = "wrong" | |
highlight_byte.append([repr(token["token"])[1:-1], stat]) | |
else: | |
highlight_byte[-1][0] += repr(token["token"])[1:-1] | |
return ( | |
texts[0], | |
highlight_base, | |
highlight_byte, | |
round(msgs_rates[0] * 100, 2), | |
) | |
def dec_fn( | |
gen_model: str, | |
text: str, | |
msg_base: int, | |
seed_scheme: str, | |
window_length: int, | |
private_key: int, | |
): | |
model, tokenizer = ModelFactory.load_model(gen_model) | |
msgs = decrypt( | |
tokenizer=tokenizer, | |
device=model.device, | |
text=text, | |
msg_base=msg_base, | |
seed_scheme=seed_scheme, | |
window_length=window_length, | |
private_key=private_key, | |
) | |
msg_text = "" | |
for i, msg in enumerate(msgs): | |
msg_text += f"Shift {i}: {msg}\n\n" | |
return msg_text | |
if __name__ == "__main__": | |
enc = gr.Interface( | |
fn=enc_fn, | |
inputs=[ | |
gr.Dropdown( | |
value=GlobalConfig.get("encrypt.default", "gen_model"), | |
choices=ModelFactory.get_models_names(), | |
), | |
gr.Textbox(), | |
gr.Textbox(), | |
gr.Number(int(GlobalConfig.get("encrypt.default", "start_pos"))), | |
gr.Number(float(GlobalConfig.get("encrypt.default", "delta"))), | |
gr.Number(int(GlobalConfig.get("encrypt.default", "msg_base"))), | |
gr.Dropdown( | |
value=GlobalConfig.get("encrypt.default", "seed_scheme"), | |
choices=SeedSchemeFactory.get_schemes_name(), | |
), | |
gr.Number( | |
int(GlobalConfig.get("encrypt.default", "window_length")) | |
), | |
gr.Number(int(GlobalConfig.get("encrypt.default", "private_key"))), | |
gr.Number(bool(GlobalConfig.get("encrypt.default", "do_sample"))), | |
gr.Number( | |
float( | |
GlobalConfig.get("encrypt.default", "min_new_tokens_ratio") | |
) | |
), | |
gr.Number( | |
float( | |
GlobalConfig.get("encrypt.default", "max_new_tokens_ratio") | |
) | |
), | |
gr.Number(int(GlobalConfig.get("encrypt.default", "num_beams"))), | |
gr.Number( | |
float(GlobalConfig.get("encrypt.default", "repetition_penalty")) | |
), | |
], | |
outputs=[ | |
gr.Textbox( | |
label="Text containing message", | |
show_label=True, | |
show_copy_button=True, | |
), | |
gr.HighlightedText( | |
label="Text containing message (Base highlighted)", | |
combine_adjacent=False, | |
show_legend=True, | |
color_map={"correct": "green", "wrong": "red"}, | |
), | |
gr.HighlightedText( | |
label="Text containing message (Byte highlighted)", | |
combine_adjacent=False, | |
show_legend=True, | |
color_map={"correct": "green", "wrong": "red"}, | |
), | |
gr.Number(label="Percentage of message in text", show_label=True), | |
], | |
) | |
dec = gr.Interface( | |
fn=dec_fn, | |
inputs=[ | |
gr.Dropdown( | |
value=GlobalConfig.get("decrypt.default", "gen_model"), | |
choices=ModelFactory.get_models_names(), | |
), | |
gr.Textbox(), | |
gr.Number(int(GlobalConfig.get("decrypt.default", "msg_base"))), | |
gr.Dropdown( | |
value=GlobalConfig.get("decrypt.default", "seed_scheme"), | |
choices=SeedSchemeFactory.get_schemes_name(), | |
), | |
gr.Number( | |
int(GlobalConfig.get("decrypt.default", "window_length")) | |
), | |
gr.Number(int(GlobalConfig.get("decrypt.default", "private_key"))), | |
], | |
outputs=[ | |
gr.Textbox(label="Message", show_label=True), | |
], | |
) | |
app = gr.TabbedInterface([enc, dec], ["Encrytion", "Decryption"]) | |
app.launch(share=True) | |