import torch
import gradio as gr

from model_factory import ModelFactory
from stegno import generate, decrypt
from seed_scheme_factory import SeedSchemeFactory
from global_config import GlobalConfig


def _token_label(token) -> str:
    """Return a printable label for *token* with control characters escaped.

    ``repr`` escapes characters such as newlines and tabs; slicing off the
    first and last characters drops the quote marks ``repr`` adds.
    """
    return repr(token)[1:-1]


def _match_status(expected, actual):
    """Classify one embedded unit for highlighting.

    Returns None when *expected* is -1 (the token carries no message value),
    "correct" when *expected* equals *actual*, and "wrong" otherwise.
    """
    if expected == -1:
        return None
    return "correct" if expected == actual else "wrong"


def _highlight_by_base(tokens_info) -> list:
    """Build one (label, status) pair per token, judged on its base digit."""
    return [
        (
            _token_label(tok["token"]),
            _match_status(tok["base_msg"], tok["base_enc"]),
        )
        for tok in tokens_info
    ]


def _highlight_by_byte(tokens_info) -> list:
    """Merge consecutive tokens that share a ``byte_id`` into one span.

    Each span's status is taken from the first token of its group; later
    tokens in the same group only contribute their text.
    """
    spans = []
    prev_byte_id = None
    for tok in tokens_info:
        label = _token_label(tok["token"])
        if spans and tok["byte_id"] == prev_byte_id:
            spans[-1][0] += label
        else:
            spans.append(
                [label, _match_status(tok["byte_msg"], tok["byte_enc"])]
            )
        prev_byte_id = tok["byte_id"]
    return [tuple(span) for span in spans]


def enc_fn(
    gen_model: str,
    prompt: str,
    msg: str,
    start_pos: int,
    gamma: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int,
    private_key: int,
    max_new_tokens_ratio: float,
    num_beams: int,
    repetition_penalty: float,
):
    """Generate text from *prompt* with *msg* hidden in it.

    All settings are forwarded to ``stegno.generate``; the message is
    encoded to bytes with ``str.encode``.

    Returns:
        A 4-tuple of the generated text, base-level highlight spans,
        byte-level highlight spans, and ``msg_rate`` as a percentage
        rounded to two decimals.
    """
    model, tokenizer = ModelFactory.load_model(gen_model)
    # gr.Number yields floats unless precision=0, so integer-valued settings
    # are coerced here; float settings are normalised the same way.
    text, msg_rate, tokens_info = generate(
        tokenizer=tokenizer,
        model=model,
        prompt=prompt,
        msg=str.encode(msg),
        start_pos_p=[int(start_pos)],
        gamma=float(gamma),
        msg_base=int(msg_base),
        seed_scheme=seed_scheme,
        window_length=int(window_length),
        private_key=int(private_key),
        max_new_tokens_ratio=float(max_new_tokens_ratio),
        num_beams=int(num_beams),
        repetition_penalty=float(repetition_penalty),
    )
    highlight_base = _highlight_by_base(tokens_info)
    highlight_byte = _highlight_by_byte(tokens_info)
    return text, highlight_base, highlight_byte, round(msg_rate * 100, 2)


def dec_fn(
    gen_model: str,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int,
    private_key: int,
):
    """Recover candidate messages from *text* via ``stegno.decrypt``.

    Only the tokenizer and the model's device are used from the loaded model.

    Returns:
        One "Shift i: <msg>" paragraph per candidate returned by ``decrypt``.
    """
    model, tokenizer = ModelFactory.load_model(gen_model)
    # Same float -> int coercion as in enc_fn for gr.Number inputs.
    msgs = decrypt(
        tokenizer=tokenizer,
        device=model.device,
        text=text,
        msg_base=int(msg_base),
        seed_scheme=seed_scheme,
        window_length=int(window_length),
        private_key=int(private_key),
    )
    return "".join(f"Shift {i}: {msg}\n\n" for i, msg in enumerate(msgs))


def _int_number(section: str, key: str):
    """Integer-only number input whose default comes from GlobalConfig."""
    return gr.Number(value=int(GlobalConfig.get(section, key)), precision=0)


def _float_number(section: str, key: str):
    """Float number input whose default comes from GlobalConfig."""
    return gr.Number(value=float(GlobalConfig.get(section, key)))


def _build_app():
    """Assemble the encryption and decryption tabs into one Gradio app."""
    enc_section = "encrypt.default"
    dec_section = "decrypt.default"

    enc = gr.Interface(
        fn=enc_fn,
        inputs=[
            gr.Dropdown(
                value=GlobalConfig.get(enc_section, "gen_model"),
                choices=ModelFactory.get_models_names(),
            ),
            gr.Textbox(),
            gr.Textbox(),
            _int_number(enc_section, "start_pos"),
            _float_number(enc_section, "gamma"),
            _int_number(enc_section, "msg_base"),
            gr.Dropdown(
                value=GlobalConfig.get(enc_section, "seed_scheme"),
                choices=SeedSchemeFactory.get_schemes_name(),
            ),
            _int_number(enc_section, "window_length"),
            _int_number(enc_section, "private_key"),
            _float_number(enc_section, "max_new_tokens_ratio"),
            _int_number(enc_section, "num_beams"),
            _float_number(enc_section, "repetition_penalty"),
        ],
        outputs=[
            gr.Textbox(
                label="Text containing message",
                show_label=True,
                show_copy_button=True,
            ),
            gr.HighlightedText(
                label="Text containing message (Base highlighted)",
                combine_adjacent=False,
                show_legend=True,
                color_map={"correct": "green", "wrong": "red"},
            ),
            gr.HighlightedText(
                label="Text containing message (Byte highlighted)",
                combine_adjacent=False,
                show_legend=True,
                color_map={"correct": "green", "wrong": "red"},
            ),
            gr.Number(label="Percentage of message in text", show_label=True),
        ],
    )

    dec = gr.Interface(
        fn=dec_fn,
        inputs=[
            gr.Dropdown(
                value=GlobalConfig.get(dec_section, "gen_model"),
                choices=ModelFactory.get_models_names(),
            ),
            gr.Textbox(),
            _int_number(dec_section, "msg_base"),
            gr.Dropdown(
                value=GlobalConfig.get(dec_section, "seed_scheme"),
                choices=SeedSchemeFactory.get_schemes_name(),
            ),
            _int_number(dec_section, "window_length"),
            _int_number(dec_section, "private_key"),
        ],
        outputs=[
            gr.Textbox(label="Message", show_label=True),
        ],
    )

    return gr.TabbedInterface([enc, dec], ["Encryption", "Decryption"])


if __name__ == "__main__":
    app = _build_app()
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    app.launch(share=True)