File size: 3,950 Bytes
341de97
 
 
 
 
 
 
 
 
 
 
 
 
7e8d9b9
341de97
 
 
 
 
 
11c7796
 
ee83d59
341de97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e8d9b9
 
 
 
 
 
 
 
341de97
7e8d9b9
341de97
 
 
7e8d9b9
341de97
 
 
1f125f1
341de97
 
 
 
 
 
ee83d59
 
7e8d9b9
 
ee83d59
 
341de97
 
 
ee83d59
 
341de97
ee83d59
 
341de97
7e8d9b9
 
341de97
 
 
7e8d9b9
 
 
ee83d59
341de97
ee83d59
341de97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import Union

import torch
import transformers

from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
    tokenizer,
    model,
    prompt: str,
    msg: bytes,
    start_pos_p: list[int],
    gamma: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
    max_new_tokens_ratio: float = 2,
    num_beams: int = 4,
    repetition_penalty: float = 1.0,
):
    """
    Generate the sequence containing the hidden data.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt.
        msg: message to hide in the text.
        start_pos_p: either ``[pos]`` for a fixed start position, or
            ``[low, high]`` to draw the start position uniformly from the
            inclusive range ``low..high`` with torch's global RNG. Only the
            first two elements are read.
        gamma: bias added to the scores of tokens in the valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed; it is also
            added to the chosen start position.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
        max_new_tokens_ratio: multiplier applied to the message length when
            computing the upper bound on generated tokens.
        num_beams: number of beams passed to ``model.generate``.
        repetition_penalty: repetition penalty passed to ``model.generate``.

    Returns:
        A tuple ``(output_text, msg_rate, tokens_info)``:
        ``output_text`` is the generated text with the prompt tokens removed.
        ``msg_rate`` and ``tokens_info`` are the first elements of the two
        values returned by ``EncryptorLogitsProcessor.validate`` on the
        re-tokenized output text.

    Raises:
        ValueError: if ``start_pos_p`` is empty, or if its ``[low, high]``
            range has ``low > high``.
    """
    if not start_pos_p:
        raise ValueError("start_pos_p must contain at least one position")
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        low, high = start_pos_p[0], start_pos_p[1]
        if low > high:
            raise ValueError(
                f"start_pos_p range is inverted: {low} > {high}"
            )
        # torch.randint's upper bound is exclusive; +1 makes `high` reachable.
        start_pos = torch.randint(low, high + 1, (1,)).item()
    # Offset by the seed window length (presumably so a full seed window
    # precedes the first encoded token — confirm against the processor).
    start_pos = int(start_pos) + window_length

    tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_size = tokenized_input.input_ids.size(1)
    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        start_pos=start_pos,
        gamma=gamma,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        tokenizer=tokenizer,
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    msg_len = logits_processor.get_message_len()
    # Lengths passed to generate() include the prompt tokens.
    min_length = prompt_size + start_pos + msg_len
    max_length = prompt_size + int(start_pos + msg_len * max_new_tokens_ratio)
    # Never ask for more than the tokenizer allows, and keep min <= max.
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)
    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_length=min_length,
        max_length=max_length,
        do_sample=True,
        num_beams=num_beams,
        repetition_penalty=float(repetition_penalty),
    )

    # Drop the prompt tokens, keeping only the newly generated part.
    output_tokens = output_tokens[:, prompt_size:]
    output_text = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )[0]
    # Validate against a re-tokenization of the decoded text.
    output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
        model.device
    )
    msg_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)

    return output_text, msg_rates[0], tokens_infos[0]


def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Recover the hidden message from a piece of generated text.

    Args:
        tokenizer: tokenizer used to turn ``text`` into token ids; its
            vocabulary ids are also handed to the decryptor.
        device: device the token ids are moved to; also passed to the
            decryptor.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.

    Returns:
        Whatever ``DecryptorProcessor.decrypt`` returns for the tokenized
        text.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)
    token_ids = encoded.input_ids

    vocab_ids = list(tokenizer.get_vocab().values())
    processor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=vocab_ids,
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    return processor.decrypt(token_ids)