Spaces:
Sleeping
Sleeping
import base64 | |
import json | |
import torch | |
from fastapi import FastAPI | |
from fastapi.openapi.utils import get_openapi | |
import uvicorn | |
from stegno import generate, decrypt | |
from utils import load_model | |
from seed_scheme_factory import SeedSchemeFactory | |
from model_factory import ModelFactory | |
from global_config import GlobalConfig | |
from schemes import DecryptionBody, EncryptionBody | |
app = FastAPI() | |
with open("resources/examples.json", "r") as f: | |
examples = json.load(f) | |
async def encrypt_api( | |
body: EncryptionBody, | |
): | |
byte_msg = base64.b64decode(body.msg) | |
model, tokenizer = ModelFactory.load_model(body.gen_model) | |
texts, msgs_rates, tokens_infos = generate( | |
tokenizer=tokenizer, | |
model=model, | |
prompt=body.prompt, | |
msg=byte_msg, | |
start_pos_p=[body.start_pos], | |
delta=body.delta, | |
msg_base=body.msg_base, | |
seed_scheme=body.seed_scheme, | |
window_length=body.window_length, | |
private_key=body.private_key, | |
max_new_tokens_ratio=body.max_new_tokens_ratio, | |
num_beams=body.num_beams, | |
repetition_penalty=body.repetition_penalty, | |
) | |
return { | |
"texts": texts, | |
"msgs_rates": msgs_rates, | |
"tokens_info": tokens_infos, | |
} | |
async def decrypt_api(body: DecryptionBody): | |
model, tokenizer = ModelFactory.load_model(body.gen_model) | |
msgs = decrypt( | |
tokenizer=tokenizer, | |
device=model.device, | |
text=body.text, | |
msg_base=body.msg_base, | |
seed_scheme=body.seed_scheme, | |
window_length=body.window_length, | |
private_key=body.private_key, | |
) | |
msg_b64 = {} | |
for i, s_msg in enumerate(msgs): | |
msg_b64[i] = [] | |
for msg in s_msg: | |
msg_b64[i].append(base64.b64encode(msg)) | |
return msg_b64 | |
async def default_config(): | |
configs = { | |
"default": { | |
"encrypt": { | |
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"), | |
"start_pos": GlobalConfig.get("encrypt.default", "start_pos"), | |
"delta": GlobalConfig.get("encrypt.default", "delta"), | |
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"), | |
"seed_scheme": GlobalConfig.get( | |
"encrypt.default", "seed_scheme" | |
), | |
"window_length": GlobalConfig.get( | |
"encrypt.default", "window_length" | |
), | |
"private_key": GlobalConfig.get( | |
"encrypt.default", "private_key" | |
), | |
"min_new_tokens_ratio": GlobalConfig.get( | |
"encrypt.default", "min_new_tokens_ratio" | |
), | |
"max_new_tokens_ratio": GlobalConfig.get( | |
"encrypt.default", "max_new_tokens_ratio" | |
), | |
"do_sample": GlobalConfig.get("encrypt.default", "do_sample"), | |
"num_beams": GlobalConfig.get("encrypt.default", "num_beams"), | |
"repetition_penalty": GlobalConfig.get( | |
"encrypt.default", "repetition_penalty" | |
), | |
}, | |
"decrypt": { | |
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"), | |
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"), | |
"seed_scheme": GlobalConfig.get( | |
"encrypt.default", "seed_scheme" | |
), | |
"window_length": GlobalConfig.get( | |
"encrypt.default", "window_length" | |
), | |
"private_key": GlobalConfig.get( | |
"encrypt.default", "private_key" | |
), | |
}, | |
}, | |
"seed_schemes": SeedSchemeFactory.get_schemes_name(), | |
"models": ModelFactory.get_models_names(), | |
} | |
return configs | |
if __name__ == "__main__": | |
# The following are mainly used to satisfy the linter | |
host = GlobalConfig.get("server", "host") | |
host = str(host) if host is not None else "0.0.0.0" | |
port = GlobalConfig.get("server", "port") | |
port = int(port) if port is not None else 8000 | |
workers = GlobalConfig.get("server", "workers") | |
workers = int(workers) if workers is not None else 1 | |
uvicorn.run("api:app", host=host, port=port, workers=workers) | |