import base64

import torch
from fastapi import FastAPI
import uvicorn

from stegno import generate, decrypt
from utils import load_model
from seed_scheme_factory import SeedSchemeFactory
from model_factory import ModelFactory
from global_config import GlobalConfig
from schemes import DecryptionBody, EncryptionBody

app = FastAPI()
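

# POST /encrypt
# Hides `body.msg` in cover text generated from `body.prompt` by the model
# named in `body.gen_model`, and returns the generated text together with the
# achieved message rate and per-token info.
# Illustrative request body (field names come from EncryptionBody; the values
# below are placeholders, not defaults -- see GET /configs for real defaults):
#   {"gen_model": "<model name>", "prompt": "Once upon a time", "msg": "secret",
#    "start_pos": 0, "gamma": 2.0, "msg_base": 2, "seed_scheme": "<scheme>",
#    "window_length": 1, "private_key": 0, "max_new_tokens_ratio": 2.0,
#    "num_beams": 4, "repetition_penalty": 1.0}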
@app.post("/encrypt")
async def encrypt_api(
body: EncryptionBody,
):
model, tokenizer = ModelFactory.load_model(body.gen_model)
text, msg_rate, tokens_info = generate(
tokenizer=tokenizer,
model=model,
prompt=body.prompt,
msg=str.encode(body.msg),
start_pos_p=[body.start_pos],
gamma=body.gamma,
msg_base=body.msg_base,
seed_scheme=body.seed_scheme,
window_length=body.window_length,
private_key=body.private_key,
max_new_tokens_ratio=body.max_new_tokens_ratio,
num_beams=body.num_beams,
repetition_penalty=body.repetition_penalty,
)
return {"text": text, "msg_rate": msg_rate, "tokens_info": tokens_info}
@app.post("/decrypt")
async def decrypt_api(body: DecryptionBody):
model, tokenizer = ModelFactory.load_model(body.gen_model)
msgs = decrypt(
tokenizer=tokenizer,
device=model.device,
text=body.text,
msg_base=body.msg_base,
seed_scheme=body.seed_scheme,
window_length=body.window_length,
private_key=body.private_key,
)
msg_b64 = {}
for i, s_msg in enumerate(msgs):
msg_b64[i] = []
for msg in s_msg:
msg_b64[i].append(base64.b64encode(msg))
return msg_b64
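

# GET /configs
# Exposes the default encrypt/decrypt parameters stored in GlobalConfig
# (the decrypt defaults reuse the "encrypt.default" entries) along with the
# names of the available seed schemes and generation models.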
@app.get("/configs")
async def default_config():
configs = {
"default": {
"encrypt": {
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
"start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
"gamma": GlobalConfig.get("encrypt.default", "gamma"),
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
"seed_scheme": GlobalConfig.get(
"encrypt.default", "seed_scheme"
),
"window_length": GlobalConfig.get(
"encrypt.default", "window_length"
),
"private_key": GlobalConfig.get(
"encrypt.default", "private_key"
),
"max_new_tokens_ratio": GlobalConfig.get(
"encrypt.default", "max_new_tokens_ratio"
),
"num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
"repetition_penalty": GlobalConfig.get(
"encrypt.default", "repetition_penalty"
),
},
"decrypt": {
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
"seed_scheme": GlobalConfig.get(
"encrypt.default", "seed_scheme"
),
"window_length": GlobalConfig.get(
"encrypt.default", "window_length"
),
"private_key": GlobalConfig.get(
"encrypt.default", "private_key"
),
},
},
"seed_schemes": SeedSchemeFactory.get_schemes_name(),
"models": ModelFactory.get_models_names(),
}
return configs
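

# Entry point: read host/port/workers from the "server" section of
# GlobalConfig, fall back to 0.0.0.0:8000 with a single worker, and serve
# this module ("api:app") with uvicorn.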
if __name__ == "__main__":
    # The explicit None checks and casts below are mainly there to satisfy
    # the linter, since GlobalConfig.get may return None.
    host = GlobalConfig.get("server", "host")
    host = str(host) if host is not None else "0.0.0.0"
    port = GlobalConfig.get("server", "port")
    port = int(port) if port is not None else 8000
    workers = GlobalConfig.get("server", "workers")
    workers = int(workers) if workers is not None else 1

    uvicorn.run("api:app", host=host, port=port, workers=workers)
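
# Quick smoke test once the server is running (hypothetical values; adjust
# host/port to match your "server" config):
#   curl http://localhost:8000/configs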