import json
from typing import Literal
from pydantic import BaseModel, Field
from global_config import GlobalConfig
from model_factory import ModelFactory
from seed_scheme_factory import SeedSchemeFactory

# Shared example payloads embedded into each model's JSON schema.
with open("resources/examples.json", "r") as f:
    examples = json.load(f)


class EncryptionBody(BaseModel):
    """Request body for encryption: generate text that hides `msg`."""

    prompt: str | list[str] = Field(title="Prompt used to generate text")
    msg: str = Field(title="Message to hide")
    gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
        default=GlobalConfig.get("encrypt.default", "gen_model"),
        title="LLM used to generate text",
    )
    start_pos: int = Field(
        default=GlobalConfig.get("encrypt.default", "start_pos"),
        title="Start position at which to encrypt the message",
        ge=0,
    )
    delta: float = Field(
        default=GlobalConfig.get("encrypt.default", "delta"),
        title="Hardness parameter",
        gt=0,
    )
    msg_base: int = Field(
        default=GlobalConfig.get("encrypt.default", "msg_base"),
        title="Base of the message used in base-encoding",
        ge=2,
    )
    seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
        default=GlobalConfig.get("encrypt.default", "seed_scheme"),
        title="Scheme used to compute the seed for the PRF",
    )
    window_length: int = Field(
        default=GlobalConfig.get("encrypt.default", "window_length"),
        title="Window length (context size) used to compute the seed for the PRF",
        ge=1,
    )
    private_key: int = Field(
        default=GlobalConfig.get("encrypt.default", "private_key"),
        title="Private key used to compute the seed for the PRF",
        ge=0,
    )
    min_new_tokens_ratio: float = Field(
        default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
        title="Minimum length of the generated text, as a ratio of the minimum length required to hide the message",
        ge=1,
    )
    max_new_tokens_ratio: float = Field(
        default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
        title="Maximum length of the generated text, as a ratio of the minimum length required to hide the message",
        ge=1,
    )
    num_beams: int = Field(
        default=GlobalConfig.get("encrypt.default", "num_beams"),
        title="Number of beams used in beam search",
        ge=1,
    )
    do_sample: bool = Field(
        default=GlobalConfig.get("encrypt.default", "do_sample"),
        title="Whether to sample tokens instead of decoding greedily",
    )
    repetition_penalty: float = Field(
        default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
        title="Penalty used to avoid repetition when sampling tokens",
        ge=1,
    )

    model_config = {
        "json_schema_extra": {"examples": [examples["encrypt"]["request"]]}
    }
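
# Hypothetical usage sketch (not part of the original module): every field except
# `prompt` and `msg` falls back to the "encrypt.default" section of GlobalConfig,
# so a minimal request validates on its own. The literal values below are made up
# for illustration only.
#
#   body = EncryptionBody(prompt="Write a short story.", msg="meet at dawn")
#   body.delta             # -> GlobalConfig.get("encrypt.default", "delta")
#   body.model_dump()      # -> full payload with all defaults filled in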


class DecryptionBody(BaseModel):
    """Request body for decryption: recover the hidden message from text."""

    text: str = Field(title="Text containing the message")
    gen_model: Literal[tuple(ModelFactory.get_models_names())] = Field(
        default=GlobalConfig.get("decrypt.default", "gen_model"),
        title="LLM used to generate text",
    )
    msg_base: int = Field(
        default=GlobalConfig.get("decrypt.default", "msg_base"),
        title="Base of the message used in base-encoding",
        ge=2,
    )
    seed_scheme: Literal[tuple(SeedSchemeFactory.get_schemes_name())] = Field(
        default=GlobalConfig.get("decrypt.default", "seed_scheme"),
        title="Scheme used to compute the seed for the PRF",
    )
    window_length: int = Field(
        default=GlobalConfig.get("decrypt.default", "window_length"),
        title="Window length (context size) used to compute the seed for the PRF",
        ge=1,
    )
    private_key: int = Field(
        default=GlobalConfig.get("decrypt.default", "private_key"),
        title="Private key used to compute the seed for the PRF",
        ge=0,
    )

    model_config = {
        "json_schema_extra": {"examples": [examples["decrypt"]["request"]]}
    }
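
# Hypothetical wiring sketch (assumption, not part of the original module): these
# request bodies are the kind of models a FastAPI app would accept. The app object,
# route paths, and handler names below are illustrative only.
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#
#   @app.post("/encrypt")
#   async def encrypt(body: EncryptionBody):
#       ...  # generate text from body.prompt that hides body.msg
#
#   @app.post("/decrypt")
#   async def decrypt(body: DecryptionBody):
#       ...  # recover the hidden message from body.text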