import torch
import transformers

from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
    tokenizer,
    model,
    prompt: str | list[str],
    msg: bytes,
    start_pos_p: list[int],
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
    min_new_tokens_ratio: float = 1,
    max_new_tokens_ratio: float = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    repetition_penalty: float = 1.0,
    generator: torch.Generator | None = None,
):
    """
    Generate the sequence containing the hidden data.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt.
        msg: message to hide in the text.
        delta: bias add to scores of token in valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
    """
    # Choose where the encoded payload starts, then shift past the seed
    # window so the first encoded token has a full context window.
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    start_pos = int(start_pos) + window_length
    # Use EOS as the padding token so prompts can be batched.
    tokenizer.pad_token = tokenizer.eos_token

    tokenized_input = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).to(model.device)
    prompt_size = tokenized_input.input_ids.size(1)

    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        start_pos=start_pos,
        delta=delta,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        tokenizer=tokenizer,
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    # Bound the generation length by the message length (scaled by the
    # ratios) and the model's maximum context size. transformers expects
    # integer lengths, so cast after scaling by the float ratios.
    min_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * min_new_tokens_ratio
    )
    max_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * max_new_tokens_ratio
    )
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)
    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_length=min_length,
        max_length=max_length,
        do_sample=do_sample,
        num_beams=num_beams,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tokenizer.eos_token_id,
        generator=generator,
    )

    # Strip the prompt, then re-tokenize the decoded text so validation runs
    # on the same token ids a receiver of the plain text would recover.
    output_tokens = output_tokens[:, prompt_size:]
    output_texts = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )
    output_tokens_post = tokenizer(
        output_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    ).to(model.device)
    msg_rates, tokens_infos = logits_processor.validate(
        output_tokens_post.input_ids
    )

    return output_texts, msg_rates, tokens_infos


def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use.
        text: text to decode.
        msg_base: base of the message.
        delta: bias added to scores of valid list.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
    """
    tokenized_input = tokenizer(text, return_tensors="pt").to(device)

    decryptor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    msg = decryptor.decrypt(tokenized_input.input_ids)

    return msg
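

# ---------------------------------------------------------------------------
# Illustrative round-trip sketch: hide a short message with `generate`, then
# recover it with `decrypt`. The checkpoint name "gpt2" and the seed-scheme
# string "dummy_hash" below are placeholder assumptions for this example;
# substitute a checkpoint and a scheme name that match processors.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    texts, msg_rates, _ = generate(
        tokenizer=tokenizer,
        model=model,
        prompt="Once upon a time",
        msg=b"hi",
        start_pos_p=[0],
        delta=5.0,
        msg_base=2,
        seed_scheme="dummy_hash",  # assumed name; pick a real scheme
        window_length=1,
    )
    print(texts[0], msg_rates[0])

    recovered = decrypt(
        tokenizer=tokenizer,
        device=model.device,
        text=texts[0],
        msg_base=2,
        seed_scheme="dummy_hash",  # must match the value used to encode
        window_length=1,
    )
    print(recovered)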