tnk2908's picture
Refactor analysis process
0186ed1
raw
history blame
4.92 kB
import base64
import json
import torch
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
import uvicorn
from stegno import generate, decrypt
from utils import load_model
from seed_scheme_factory import SeedSchemeFactory
from model_factory import ModelFactory
from global_config import GlobalConfig
from schemes import DecryptionBody, EncryptionBody
app = FastAPI()
with open("resources/examples.json", "r") as f:
examples = json.load(f)
@app.post(
"/encrypt",
responses={
200: {
"content": {
"application/json": {"example": examples["encrypt"]["response"]}
}
}
},
)
async def encrypt_api(
body: EncryptionBody,
):
byte_msg = base64.b64decode(body.msg)
model, tokenizer = ModelFactory.load_model(body.gen_model)
texts, msgs_rates, tokens_infos = generate(
tokenizer=tokenizer,
model=model,
prompt=body.prompt,
msg=byte_msg,
start_pos_p=[body.start_pos],
delta=body.delta,
msg_base=body.msg_base,
seed_scheme=body.seed_scheme,
window_length=body.window_length,
private_key=body.private_key,
max_new_tokens_ratio=body.max_new_tokens_ratio,
num_beams=body.num_beams,
repetition_penalty=body.repetition_penalty,
)
return {
"texts": texts,
"msgs_rates": msgs_rates,
"tokens_info": tokens_infos,
}
@app.post(
"/decrypt",
responses={
200: {
"content": {
"application/json": {"example": examples["decrypt"]["response"]}
}
}
},
)
async def decrypt_api(body: DecryptionBody):
model, tokenizer = ModelFactory.load_model(body.gen_model)
msgs = decrypt(
tokenizer=tokenizer,
device=model.device,
text=body.text,
msg_base=body.msg_base,
seed_scheme=body.seed_scheme,
window_length=body.window_length,
private_key=body.private_key,
)
msg_b64 = {}
for i, s_msg in enumerate(msgs):
msg_b64[i] = []
for msg in s_msg:
msg_b64[i].append(base64.b64encode(msg))
return msg_b64
@app.get(
"/configs",
responses={
200: {
"content": {
"application/json": {"example": examples["configs"]["response"]}
},
}
},
)
async def default_config():
configs = {
"default": {
"encrypt": {
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
"start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
"delta": GlobalConfig.get("encrypt.default", "delta"),
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
"seed_scheme": GlobalConfig.get(
"encrypt.default", "seed_scheme"
),
"window_length": GlobalConfig.get(
"encrypt.default", "window_length"
),
"private_key": GlobalConfig.get(
"encrypt.default", "private_key"
),
"min_new_tokens_ratio": GlobalConfig.get(
"encrypt.default", "min_new_tokens_ratio"
),
"max_new_tokens_ratio": GlobalConfig.get(
"encrypt.default", "max_new_tokens_ratio"
),
"do_sample": GlobalConfig.get("encrypt.default", "do_sample"),
"num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
"repetition_penalty": GlobalConfig.get(
"encrypt.default", "repetition_penalty"
),
},
"decrypt": {
"gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
"msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
"seed_scheme": GlobalConfig.get(
"encrypt.default", "seed_scheme"
),
"window_length": GlobalConfig.get(
"encrypt.default", "window_length"
),
"private_key": GlobalConfig.get(
"encrypt.default", "private_key"
),
},
},
"seed_schemes": SeedSchemeFactory.get_schemes_name(),
"models": ModelFactory.get_models_names(),
}
return configs
if __name__ == "__main__":
# The following are mainly used to satisfy the linter
host = GlobalConfig.get("server", "host")
host = str(host) if host is not None else "0.0.0.0"
port = GlobalConfig.get("server", "port")
port = int(port) if port is not None else 8000
workers = GlobalConfig.get("server", "workers")
workers = int(workers) if workers is not None else 1
uvicorn.run("api:app", host=host, port=port, workers=workers)