# Hugging Face Space demo (running on an NVIDIA L40S GPU).
from collections.abc import Sequence | |
import random | |
from typing import Optional | |
import gradio as gr | |
import spaces | |
import torch | |
import transformers | |
# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy), or it
# could be for another reason.
# Hugging Face Hub identifier used to load both the tokenizer and the model.
_MODEL_IDENTIFIER = 'hf-internal-testing/tiny-random-gpt2'

# Default prompts pre-filled into the UI's prompt textboxes (one box each).
# Variable-length tuple annotation: `tuple[str]` would mean exactly one item.
_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

# Maps a generated response's text to whether it was produced with the
# watermarking config (True) or without it (False).
_CORRECT_ANSWERS: dict[str, bool] = {}

# Use the first CUDA GPU when one is available; otherwise fall back to CPU.
_TORCH_DEVICE: torch.device = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
# SynthID Text watermarking settings. The demo passes this object as
# `watermarking_config` to `model.generate` for the watermarked generations;
# the standard generations are sampled without it.
# NOTE(review): presumably any detector must be built with these same keys
# and n-gram settings to recognise the watermark — confirm before changing.
_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    keys=[
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    ngram_len=5,
    context_history_size=1024,
    sampling_table_seed=0,
    sampling_table_size=1 << 16,  # 65,536
)
# Load the causal LM once at import time and place it on the chosen device
# (`nn.Module.to` returns the module itself, so the chained call is safe).
model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER).to(
    _TORCH_DEVICE
)

# Matching tokenizer. Its EOS id doubles as the padding id; presumably the
# tokenizer defines no pad token of its own — confirm if the model changes.
tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
tokenizer.pad_token_id = tokenizer.eos_token_id
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> Sequence[str]:
    """Sample one output sequence per prompt from the module-level model.

    All prompts are tokenized and generated together as a single batch.

    Args:
        prompts: Prompt strings to continue.
        watermarking_config: Optional SynthID Text config. When given it is
            forwarded to ``model.generate`` so the sampled text is watermarked;
            when ``None`` generation is unwatermarked.

    Returns:
        The decoded output sequences, one per prompt in input order, with
        special tokens (padding / EOS) stripped.
    """
    # Prompts of different token lengths must be padded to stack into one
    # tensor. Pad on the left so each row ends at its last real prompt token,
    # which is where a decoder-only model continues generating from.
    tokenized_prompts = tokenizer(
        list(prompts),
        return_tensors='pt',
        padding=True,
        padding_side='left',
    ).to(_TORCH_DEVICE)
    output_sequences = model.generate(
        **tokenized_prompts,  # includes the attention mask marking padding
        watermarking_config=watermarking_config,
        do_sample=True,
        max_length=500,  # total length cap, prompt tokens included
        top_k=40,
    )
    # The pad id is the EOS id; strip special tokens so neither leaks into the
    # text shown to users.
    return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
with gr.Blocks() as demo:
    # One editable textbox per default prompt.
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in _PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Step 2 (hidden until generation finishes): the user guesses which
    # generations are watermarked.
    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    # Step 3 (hidden until Reveal is clicked): ground truth vs. the guesses.
    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selections are '
                'marked as correct or incorrect in the text.'
            ),
        )
        # TODO: detect_btn is made visible by `reveal` but has no click
        # handler yet, so clicking it currently does nothing.
        detect_btn = gr.Button('Detect', visible=False)

    # Per-session mapping: response text -> True if it was watermarked.
    # Held in gr.State rather than a module-level dict so concurrent users
    # and repeated runs never see each other's (or stale) answers.
    answers_state = gr.State({})

    def generate(*prompts):
        """Generate standard and watermarked responses and show them shuffled.

        Args:
            *prompts: Current text of each prompt textbox, in order.

        Returns:
            Component updates hiding the Generate button, showing the guessing
            column with the shuffled responses, and the new session answers.
        """
        standard = generate_outputs(prompts=prompts)
        watermarked = generate_outputs(
            prompts=prompts,
            watermarking_config=_WATERMARK_CONFIG,
        )
        responses = list(standard) + list(watermarked)
        random.shuffle(responses)
        watermarked_texts = set(watermarked)
        # Dict keeps the shuffled order, which `reveal` reuses for display.
        answers = {
            response: response in watermarked_texts for response in responses
        }
        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(responses),
            reveal_btn: gr.Button(visible=True),
            answers_state: answers,
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[
            generate_btn,
            generations_col,
            generations_grp,
            reveal_btn,
            answers_state,
        ],
    )

    def reveal(user_selections: list[str], answers: dict[str, bool]):
        """Score the user's guesses against this session's ground truth.

        Args:
            user_selections: Responses the user marked as watermarked.
            answers: This session's response -> is-watermarked mapping.

        Returns:
            Component updates showing every response prefixed with
            'Correct!' or 'Incorrect.', with watermarked ones checked.
        """
        selected = set(user_selections or ())
        choices: list[str] = []
        value: list[str] = []
        for response, is_watermarked in (answers or {}).items():
            # A guess is correct exactly when "selected" matches the truth.
            if is_watermarked == (response in selected):
                choice = f'Correct! {response}'
            else:
                choice = f'Incorrect. {response}'
            choices.append(choice)
            if is_watermarked:
                value.append(choice)
        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=[generations_grp, answers_state],
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )

if __name__ == '__main__':
    demo.launch()