# PromptCARE / hard_prompt/autoprompt/inject_watermark.py
# Uploaded by homeway — commit 7713b1f ("Add application file")
import time
import math
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from . import utils, metrics, model_wrapper
from datetime import datetime, timedelta, timezone
# Constant UTC+8 offset (labelled 'Asia/Shanghai') used when timestamping saved results.
SHA_TZ = timezone(timedelta(hours=8), name='Asia/Shanghai')
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Number of outer iterations between two trigger-token flips.
_TRIGGER_FLIP_EVERY = 10


def _next_batch(train_iter, train_loader):
    """Return ``(batch, iterator)``, restarting *train_loader* when the iterator is exhausted.

    Raises:
        RuntimeError: if *train_loader* yields no batch at all even after a restart
            (e.g. the dataset is smaller than one batch and ``drop_last`` is set).
    """
    try:
        return next(train_iter), train_iter
    except StopIteration:
        train_iter = iter(train_loader)
    try:
        return next(train_iter), train_iter
    except StopIteration:
        raise RuntimeError("train_loader yielded no batches; check the dataset size and args.bsz") from None


def _masked_grad_sum(grad, mask, num_tokens, device):
    """Gather the gradient positions selected by *mask* and sum them over the batch.

    Args:
        grad: rank-3 gradient tensor ``(bsz, seq_len, emb_dim)`` read from
            ``OutputStorage.get()`` after a backward pass.
        mask: selection mask for the same batch, broadcast over the embedding dimension
            via ``unsqueeze(-1)``. Assumed to be boolean ``(bsz, seq_len)`` with exactly
            ``num_tokens`` selected positions per row — the ``view`` below raises otherwise.
        num_tokens: number of selected positions per example.
        device: device the mask is moved to before selection.

    Returns:
        ``(num_tokens, emb_dim)`` tensor: per-token gradient summed over the batch.
    """
    bsz, _, emb_dim = grad.size()
    selected = torch.masked_select(grad, mask.unsqueeze(-1).to(device))
    return selected.view(bsz, num_tokens, emb_dim).sum(dim=0)


def run_model(args):
    """Jointly optimise a discrete prompt and a watermark trigger with HotFlip.

    Each outer iteration (``args.iters`` in total):
      1. accumulates embedding gradients over ``args.accumulation_steps`` training
         batches — for the prompt tokens from clean and watermark (poisoned) samples,
         and for the trigger tokens from watermark samples only;
      2. picks one prompt position, asks ``utils.hotflip_attack`` for ``args.num_cand``
         replacement tokens, and keeps the best one if it outscores the current prompt
         on fresh training batches;
      3. every ``_TRIGGER_FLIP_EVERY`` iterations does the same for one trigger position
         (skipped, with a warning, if no watermark sample was seen in step 1);
      4. evaluates the prompt without and with the trigger on the eval split and
         writes a results dict to ``args.output`` with ``torch.save``.

    Args:
        args: parsed command-line namespace. Fields read here include ``seed``,
            ``device``, ``model_name``, ``dataset_name``, ``prompt``, ``trigger``,
            ``bsz``, ``iters``, ``accumulation_steps``, ``num_cand`` and ``output``;
            the namespace is also forwarded to ``utils.load_pretrained`` and
            ``utils.load_datasets``, and every field is stringified into the results.
    """
    metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
    utils.set_seed(args.seed)
    device = args.device

    # load model, tokenizer, config
    logger.info('-> Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)

    # embedding_gradient.get() is read after every backward pass; embeddings.weight
    # is the candidate pool handed to the HotFlip search.
    embedding_gradient = utils.OutputStorage(model, config)
    embeddings = embedding_gradient.embeddings
    predictor = model_wrapper.ModelWrapper(model, tokenizer)

    # Initial prompt / trigger: user-supplied token ids, or distinct random vocabulary ids.
    if args.prompt:
        prompt_ids = list(args.prompt)
    else:
        prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    if args.trigger:
        key_ids = list(args.trigger)
    else:
        key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist()
    print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
    print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}')
    # Shape (1, num_tokens) so a single position can be replaced in place with [:, fidx].
    prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0)
    key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0)

    # load dataset & evaluation function
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    # Indexed by each batch's "idx"; entries equal to 1 are treated as watermark (poisoned) samples.
    pidx = datasets.train_dataset.poison_idx

    # saving results
    best_results = {
        "curr_ben_acc": -float('inf'),
        "curr_wmk_acc": -float('inf'),
        "best_clean_acc": -float('inf'),
        "best_poison_asr": -float('inf'),
        "best_key_ids": None,
        "best_prompt_ids": None,
        "best_key_token": None,
        "best_prompt_token": None,
    }
    # Best value of the selection metric (acc, or F1Score for record/multirc). Tracked
    # separately because "best_clean_acc" always stores accuracy, so comparing against
    # it would mix metrics on the F1 datasets.
    best_clean_metric = -float('inf')
    # Record the run configuration next to the metrics (tensors as lists, everything as str).
    for k, v in vars(args).items():
        best_results[str(k)] = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
    torch.save(best_results, args.output)

    # multi-task attack, \min_{x_trigger} \min_{x_{prompt}} Loss
    train_iter = iter(train_loader)
    pharx = tqdm(range(1, 1 + args.iters))
    for iters in pharx:
        start = float(time.time())
        predictor._model.zero_grad()
        prompt_averaged_grad = None
        trigger_averaged_grad = None

        # ---- 1. accumulate prompt / trigger gradients ----
        phar = tqdm(range(args.accumulation_steps))
        evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
        for step in phar:
            predictor._model.train()
            model_inputs, train_iter = _next_batch(train_iter, train_loader)
            c_labels = model_inputs["labels"].to(device)
            p_labels = model_inputs["key_labels"].to(device)

            # clean samples: prompt gradient with no trigger inserted
            predictor._model.zero_grad()
            c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean()
            loss.backward()
            cp_grad = _masked_grad_sum(embedding_gradient.get(), model_inputs['prompt_mask'],
                                       tokenizer.num_prompt_tokens, device)
            if prompt_averaged_grad is None:
                prompt_averaged_grad = cp_grad / args.accumulation_steps
            else:
                prompt_averaged_grad += cp_grad / args.accumulation_steps

            # poison samples: only the batch rows flagged in pidx
            idx = model_inputs["idx"]
            poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
            if len(poison_idx) > 0:
                c_labels = c_labels[poison_idx].clone()
                p_labels = model_inputs["key_labels"][poison_idx].to(device)

                # trigger gradient: loss with key_labels passed as the first label argument
                predictor._model.zero_grad()
                p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean()
                loss.backward()
                pt_grad = _masked_grad_sum(embedding_gradient.get(), model_inputs['key_trigger_mask'][poison_idx],
                                           tokenizer.num_key_tokens, device)
                if trigger_averaged_grad is None:
                    trigger_averaged_grad = pt_grad / args.accumulation_steps
                else:
                    trigger_averaged_grad += pt_grad / args.accumulation_steps

                # prompt gradient on the same triggered inputs, label arguments swapped
                predictor._model.zero_grad()
                p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean()
                loss.backward()
                pp_grad = _masked_grad_sum(embedding_gradient.get(), model_inputs['key_prompt_mask'][poison_idx],
                                           tokenizer.num_prompt_tokens, device)
                prompt_averaged_grad += pp_grad / args.accumulation_steps
            del model_inputs
            trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad
            phar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float(): 0.8f}')

        # ---- 2. HotFlip one prompt token ----
        size = min(tokenizer.num_prompt_tokens, 1)
        prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
        for fidx in prompt_flip_idx:
            prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight, increase_loss=False,
                                                     num_candidates=args.num_cand, filter=None)
            # select best prompt: score the current prompt and every candidate on the same fresh batches
            prompt_denom, prompt_current_score = 0, 0
            prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
            phar = tqdm(range(args.accumulation_steps))
            for step in phar:
                model_inputs, train_iter = _next_batch(train_iter, train_loader)
                c_labels = model_inputs["labels"].to(device)
                idx = model_inputs["idx"]
                poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
                if len(poison_idx) == 0:
                    # no watermark sample in this batch: score the first example with the trigger instead
                    poison_idx = np.array([0])
                with torch.no_grad():
                    # current prompt: clean samples, then triggered samples against their clean labels
                    c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
                    prompt_current_score += evaluation_fn(c_logits, c_labels).sum()
                    prompt_denom += c_labels.size(0)
                    p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                    prompt_current_score += evaluation_fn(p_logits, c_labels[poison_idx]).sum()
                    prompt_denom += len(poison_idx)

                    # every candidate, scored the same way
                    for i, candidate in enumerate(prompt_candidates):
                        tmp_prompt = prompt_ids.clone()
                        tmp_prompt[:, fidx] = candidate
                        predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None)
                        prompt_candidate_scores[i] += evaluation_fn(predict_logits, c_labels).sum()
                        p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx)
                        prompt_candidate_scores[i] += evaluation_fn(p_logits, c_labels[poison_idx]).sum()
                del model_inputs
                phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve prompt in candidates token_to_flip:{fidx}")
            del tmp_prompt, c_logits, p_logits, c_labels

            if (prompt_candidate_scores > prompt_current_score).any():
                best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone()
                best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone()
                prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone()
                print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
                print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
        del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates

        # ---- 3. after every _TRIGGER_FLIP_EVERY prompt iterations, HotFlip one trigger token ----
        update_trigger = iters % _TRIGGER_FLIP_EVERY == 0
        if update_trigger and trigger_averaged_grad is None:
            # No watermark sample reached step 1, so there is no trigger gradient to flip with.
            logger.warning('-> iteration %d: no watermark samples during gradient accumulation; skipping trigger update', iters)
            update_trigger = False
        if update_trigger:
            size = min(tokenizer.num_key_tokens, 1)
            key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist()
            for fidx in key_to_flip:
                trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight, increase_loss=False,
                                                          num_candidates=args.num_cand, filter=None)
                # select best trigger: every example of the batch is scored with the trigger inserted
                trigger_denom, trigger_current_score = 0, 0
                trigger_candidate_scores = torch.zeros(args.num_cand, device=device)
                phar = tqdm(range(args.accumulation_steps))
                for step in phar:
                    model_inputs, train_iter = _next_batch(train_iter, train_loader)
                    p_labels = model_inputs["key_labels"].to(device)
                    poison_idx = np.arange(len(p_labels))
                    with torch.no_grad():
                        p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                        trigger_current_score += evaluation_fn(p_logits, p_labels).sum()
                        trigger_denom += p_labels.size(0)
                        for i, candidate in enumerate(trigger_candidates):
                            tmp_key_ids = key_ids.clone()
                            tmp_key_ids[:, fidx] = candidate
                            p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx)
                            trigger_candidate_scores[i] += evaluation_fn(p_logits, p_labels).sum()
                    del model_inputs
                    phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve trigger in candidates token_to_flip:{fidx}")
                if (trigger_candidate_scores > trigger_current_score).any():
                    best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone()
                    best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone()
                    key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone()
                    print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13): 0.4f}')
                    print(f"-> best_trigger :{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}")
            del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits

        # ---- 4. evaluation for clean & watermark samples ----
        clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids)
        poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids)
        clean_metric = clean_results[metric]
        if clean_metric > best_clean_metric:
            best_clean_metric = clean_metric
            best_results["best_prompt_ids"] = prompt_ids.tolist()
            best_results["best_prompt_token"] = utils.ids_to_strings(tokenizer, prompt_ids)
            best_results["best_clean_acc"] = clean_results["acc"]
            best_results["best_key_ids"] = key_ids.tolist()
            best_results["best_key_token"] = utils.ids_to_strings(tokenizer, key_ids)
            best_results["best_poison_asr"] = poison_results['acc']
            for key in clean_results.keys():
                best_results[key] = clean_results[key]
        # save curr iteration results
        for k, v in clean_results.items():
            best_results[f"curr_ben_{k}"] = v
        for k, v in poison_results.items():
            best_results[f"curr_wmk_{k}"] = v
        best_results["curr_prompt"] = prompt_ids.tolist()
        best_results["curr_trigger"] = key_ids.tolist()
        del evaluation_fn

        # Report the selection metric under its own label (acc, or F1Score for record/multirc).
        curr_clean = best_results[f"curr_ben_{metric}"]
        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{curr_clean:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}')
        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{curr_clean:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n')

        # save results
        cost_time = float(time.time()) - start
        utc_now = datetime.now(timezone.utc)
        pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}")
        best_results["curr_iters"] = iters
        best_results["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
        best_results["curr_cost"] = int(cost_time)
        torch.save(best_results, args.output)
if __name__ == '__main__':
    # Relative import: needs package context, i.e. launch with `python -m <package>.<module>`.
    from .augments import get_args

    args = get_args()
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=log_level)
    run_model(args)