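"""Prompt-watermark trainer.

Jointly optimizes a discrete prompt and a trigger ("key") token sequence with a
gradient-guided HotFlip search: the prompt is updated every iteration, the
trigger once every 10 iterations, and results are checkpointed to
``args.output`` after each iteration.
"""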
import time
import logging
from datetime import datetime, timedelta, timezone

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from . import utils, metrics, model_wrapper

# UTC+8, used to timestamp saved results.
SHA_TZ = timezone(
    timedelta(hours=8),
    name='Asia/Shanghai',
)

logger = logging.getLogger(__name__)
def run_model(args):
    # ReCoRD and MultiRC are scored with F1; all other tasks use accuracy.
    metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
    utils.set_seed(args.seed)
    device = args.device

    # load model, tokenizer, config
    logger.info('-> Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    embedding_gradient = utils.OutputStorage(model, config)
    embeddings = embedding_gradient.embeddings
    predictor = model_wrapper.ModelWrapper(model, tokenizer)
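    # `embedding_gradient` hooks the embedding layer so that, after a backward
    # pass, `embedding_gradient.get()` returns the gradient of the loss w.r.t.
    # the embedded input tokens; this gradient drives the HotFlip search below.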
    # Initialize prompt and trigger (key) tokens from args, or sample random
    # vocabulary ids when none are given.
    if args.prompt:
        prompt_ids = list(args.prompt)
    else:
        prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    if args.trigger:
        key_ids = list(args.trigger)
    else:
        key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist()
    print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
    print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}')
    prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0)
    key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0)
    # load dataset & evaluation function
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
    pidx = datasets.train_dataset.poison_idx
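    # `pidx` flags which training examples are poisoned: entries equal to 1 mark
    # the watermark subset that carries the trigger.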
    # saving results
    best_results = {
        "curr_ben_acc": -float('inf'),
        "curr_wmk_acc": -float('inf'),
        "best_clean_acc": -float('inf'),
        "best_poison_asr": -float('inf'),
        "best_key_ids": None,
        "best_prompt_ids": None,
        "best_key_token": None,
        "best_prompt_token": None,
    }
    # Record every run argument alongside the results.
    for k, v in vars(args).items():
        v = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
        best_results[str(k)] = v
    torch.save(best_results, args.output)
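    # The results dict is re-saved after every iteration, so partial results
    # survive an interrupted run.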
    # multi-task attack: \min_{x_{trigger}} \min_{x_{prompt}} Loss
    train_iter = iter(train_loader)
    pharx = tqdm(range(1, 1 + args.iters))
    for iters in pharx:
        start = time.time()
        predictor._model.zero_grad()
        prompt_averaged_grad = None
        trigger_averaged_grad = None

        # for prompt optimization
        poison_step = 0
        phar = tqdm(range(args.accumulation_steps))
        evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
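        # Phase 1: accumulate the average gradient of the loss at the prompt and
        # trigger token positions over `args.accumulation_steps` mini-batches.
        # These averaged gradients feed the HotFlip token-replacement steps below.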
        for step in phar:
            predictor._model.train()
            try:
                model_inputs = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                model_inputs = next(train_iter)
            c_labels = model_inputs["labels"].to(device)
            p_labels = model_inputs["key_labels"].to(device)

            # Clean samples: accumulate the gradient at the prompt positions.
            predictor._model.zero_grad()
            c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean()
            # loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
            loss.backward()
            c_grad = embedding_gradient.get()
            bsz, _, emb_dim = c_grad.size()
            selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
            cp_grad = torch.masked_select(c_grad, selection_mask)
            cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
            if prompt_averaged_grad is None:
                prompt_averaged_grad = cp_grad.sum(dim=0).clone() / args.accumulation_steps
            else:
                prompt_averaged_grad += cp_grad.sum(dim=0).clone() / args.accumulation_steps
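            # Poisoned samples add two extra backward passes: one for the gradient
            # at the trigger positions (loss toward the key labels) and one at the
            # prompt positions (loss toward the clean labels).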
            idx = model_inputs["idx"]
            poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
            if len(poison_idx) > 0:
                poison_step += 1
                c_labels = c_labels[poison_idx].clone()
                p_labels = model_inputs["key_labels"][poison_idx].to(device)

                # Gradient at the trigger positions (loss toward the key labels).
                predictor._model.zero_grad()
                p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean()
                # loss = evaluation_fn.get_loss(p_logits, p_labels).mean()
                loss.backward()
                p_grad = embedding_gradient.get()
                bsz, _, emb_dim = p_grad.size()
                selection_mask = model_inputs['key_trigger_mask'][poison_idx].unsqueeze(-1).to(device)
                pt_grad = torch.masked_select(p_grad, selection_mask)
                pt_grad = pt_grad.view(bsz, tokenizer.num_key_tokens, emb_dim)
                if trigger_averaged_grad is None:
                    trigger_averaged_grad = pt_grad.sum(dim=0).clone() / args.accumulation_steps
                else:
                    trigger_averaged_grad += pt_grad.sum(dim=0).clone() / args.accumulation_steps

                # Gradient at the prompt positions (loss toward the clean labels).
                predictor._model.zero_grad()
                p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean()
                # loss = evaluation_fn.get_loss(p_logits, c_labels).mean()
                loss.backward()
                p_grad = embedding_gradient.get()
                selection_mask = model_inputs['key_prompt_mask'][poison_idx].unsqueeze(-1).to(device)
                pp_grad = torch.masked_select(p_grad, selection_mask)
                pp_grad = pp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
                prompt_averaged_grad += pp_grad.sum(dim=0).clone() / args.accumulation_steps
                '''
                if trigger_averaged_grad is None:
                    prompt_averaged_grad = (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
                    trigger_averaged_grad = pt_grad.sum(dim=0) / args.accumulation_steps
                else:
                    prompt_averaged_grad += (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
                    trigger_averaged_grad += pt_grad.sum(dim=0) / args.accumulation_steps
                '''
            del model_inputs
            trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad
            phar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float():0.8f}')
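        # Phase 2: HotFlip update of the prompt. One randomly chosen prompt
        # position is flipped per iteration; candidates are ranked by the
        # first-order (gradient-dot-embedding) estimate of the loss change, and
        # `increase_loss=False` keeps the candidates expected to lower the loss.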
        size = min(tokenizer.num_prompt_tokens, 1)
        prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
        for fidx in prompt_flip_idx:
            prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight,
                                                     increase_loss=False, num_candidates=args.num_cand,
                                                     filter=None)
            # select best prompt
            prompt_denom, prompt_current_score = 0, 0
            prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
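            # Score the current prompt and all candidates on the same fresh
            # batches, over both clean and poisoned inputs; a candidate is accepted
            # only if it beats the current prompt's accumulated score.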
            phar = tqdm(range(args.accumulation_steps))
            for step in phar:
                try:
                    model_inputs = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    model_inputs = next(train_iter)
                c_labels = model_inputs["labels"].to(device)
                # eval clean samples
                with torch.no_grad():
                    c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
                    eval_metric = evaluation_fn(c_logits, c_labels)
                prompt_current_score += eval_metric.sum()
                prompt_denom += c_labels.size(0)
                # eval poison samples
                idx = model_inputs["idx"]
                poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
                if len(poison_idx) == 0:
                    # No poisoned example in this batch: fall back to the first example.
                    poison_idx = np.array([0])
                with torch.no_grad():
                    p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                    eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
                prompt_current_score += eval_metric.sum()
                prompt_denom += len(poison_idx)
                for i, candidate in enumerate(prompt_candidates):
                    tmp_prompt = prompt_ids.clone()
                    tmp_prompt[:, fidx] = candidate
                    # eval clean samples
                    with torch.no_grad():
                        predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None)
                        eval_metric = evaluation_fn(predict_logits, c_labels)
                    prompt_candidate_scores[i] += eval_metric.sum()
                    # eval poison samples
                    with torch.no_grad():
                        p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx)
                        eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
                    prompt_candidate_scores[i] += eval_metric.sum()
                del model_inputs
                phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve prompt in candidates token_to_flip:{fidx}")
            del tmp_prompt, c_logits, p_logits, c_labels
            if (prompt_candidate_scores > prompt_current_score).any():
                best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone()
                best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone()
                prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone()
                print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13):0.4f}')
                print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
        del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates
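        # Trigger (key) optimization runs on a slower schedule than the prompt,
        # so the prompt can adapt to the current trigger between trigger flips.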
        # After every 10 prompt updates, optimize the trigger once.
        if iters > 0 and iters % 10 == 0:
            size = min(tokenizer.num_key_tokens, 1)
            key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist()
            for fidx in key_to_flip:
                trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight,
                                                          increase_loss=False, num_candidates=args.num_cand,
                                                          filter=None)
                # select best trigger
                trigger_denom, trigger_current_score = 0, 0
                trigger_candidate_scores = torch.zeros(args.num_cand, device=device)
                phar = tqdm(range(args.accumulation_steps))
                for step in phar:
                    try:
                        model_inputs = next(train_iter)
                    except StopIteration:
                        train_iter = iter(train_loader)
                        model_inputs = next(train_iter)
                    p_labels = model_inputs["key_labels"].to(device)
                    # Treat the whole batch as poisoned when scoring the trigger.
                    poison_idx = np.arange(len(p_labels))
                    with torch.no_grad():
                        p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
                        eval_metric = evaluation_fn(p_logits, p_labels)
                    trigger_current_score += eval_metric.sum()
                    trigger_denom += p_labels.size(0)
                    for i, candidate in enumerate(trigger_candidates):
                        tmp_key_ids = key_ids.clone()
                        tmp_key_ids[:, fidx] = candidate
                        with torch.no_grad():
                            p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx)
                            eval_metric = evaluation_fn(p_logits, p_labels)
                        trigger_candidate_scores[i] += eval_metric.sum()
                    del model_inputs
                    phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve trigger in candidates token_to_flip:{fidx}")
                if (trigger_candidate_scores > trigger_current_score).any():
                    best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone()
                    best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone()
                    key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone()
                    print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13):0.4f}')
                    print(f"-> best_trigger :{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}")
            del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits
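        # With this iteration's updates done, score the dev set with and without
        # the trigger to track benign accuracy and watermark ASR.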
        # Evaluation for clean & watermark samples
        clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids)
        poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids)
        clean_metric = clean_results[metric]
        if clean_metric > best_results["best_clean_acc"]:
            prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
            best_results["best_prompt_ids"] = prompt_ids.tolist()
            best_results["best_prompt_token"] = prompt_token
            # Store the same metric that is used for the comparison above, so the
            # best-checkpoint test stays consistent across iterations.
            best_results["best_clean_acc"] = clean_metric
            key_token = utils.ids_to_strings(tokenizer, key_ids)
            best_results["best_key_ids"] = key_ids.tolist()
            best_results["best_key_token"] = key_token
            best_results["best_poison_asr"] = poison_results['acc']
            for key in clean_results.keys():
                best_results[key] = clean_results[key]
        # save current-iteration results
        for k, v in clean_results.items():
            best_results[f"curr_ben_{k}"] = v
        for k, v in poison_results.items():
            best_results[f"curr_wmk_{k}"] = v
        best_results["curr_prompt"] = prompt_ids.tolist()
        best_results["curr_trigger"] = key_ids.tolist()
        del evaluation_fn
        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}')
        print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n')
        # save results
        cost_time = time.time() - start
        utc_now = datetime.now(timezone.utc)
        pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}")
        best_results["curr_iters"] = iters
        best_results["curr_times"] = utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S')
        best_results["curr_cost"] = int(cost_time)
        torch.save(best_results, args.output)
if __name__ == '__main__':
    from .augments import get_args
    args = get_args()
    if args.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level)
    run_model(args)