# PromptCARE / hard_prompt/autoprompt/create_prompt.py
# Author: homeway — commit 7713b1f ("Add application file")
import logging
import time
from datetime import datetime, timezone

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from . import utils, metrics
from .model_wrapper import ModelWrapper
# Module-level logger. Handlers and level are configured by the entry point
# (``logging.basicConfig`` in the ``__main__`` guard at the bottom of this file).
logger = logging.getLogger(__name__)
def get_embeddings(model, config):
    """Return the word-piece embedding module of ``model``.

    The backbone is the attribute of ``model`` whose name is
    ``config.model_type``; its ``embeddings.word_embeddings`` member is
    returned as-is. A missing attribute raises ``AttributeError``.
    """
    return getattr(model, config.model_type).embeddings.word_embeddings
def _next_batch(train_iter, train_loader):
    """Fetch the next training batch, restarting ``train_loader`` when exhausted.

    Args:
        train_iter: the current iterator over ``train_loader``.
        train_loader: the DataLoader to restart from once ``train_iter`` ends.

    Returns:
        A ``(batch, train_iter)`` pair. ``train_iter`` is a fresh iterator
        whenever the previous one ran out, so callers must keep the returned one.

    Raises:
        ValueError: if ``train_loader`` yields no batches at all.
    """
    try:
        batch = next(train_iter)
    except StopIteration:
        # End of a pass: start a new one. Only exhaustion is handled here;
        # genuine data-loading / collation errors propagate to the caller.
        train_iter = iter(train_loader)
        try:
            batch = next(train_iter)
        except StopIteration:
            raise ValueError("train_loader yielded no batches") from None
    return batch, train_iter


def run_model(args):
    """Search for a discrete prompt using gradient-guided token flips.

    Each outer iteration (``args.iters`` in total):
      1. Accumulates the loss gradient at the prompt positions over
         ``args.accumulation_steps`` training batches.
      2. For up to two randomly chosen prompt positions, asks
         ``utils.hotflip_attack`` for ``args.num_cand`` replacement token ids.
         It keeps the best candidate if any candidate's summed train metric
         beats the current prompt's, on a fresh set of
         ``args.accumulation_steps`` batches.
      3. Evaluates the prompt on the dev set. If ``metric_key`` improved, it
         records the prompt and metrics in ``best_results``, which is written
         to ``args.output`` with ``torch.save`` after every iteration.

    Args:
        args: parsed options. Fields read directly here: ``dataset_name``,
            ``seed``, ``device``, ``model_name``, ``prompt``, ``bsz``,
            ``iters``, ``accumulation_steps``, ``num_cand``, ``output``.
            The whole namespace is also passed to ``utils.load_pretrained``
            and ``utils.load_datasets``, and snapshotted into the saved results.

    Raises:
        ValueError: if ``args.prompt`` does not hold exactly
            ``tokenizer.num_prompt_tokens`` ids, or the train loader is empty.
    """
    # record / multirc are scored by F1; all other datasets by accuracy.
    metric_key = "F1Score" if args.dataset_name in ("record", "multirc") else "acc"
    utils.set_seed(args.seed)
    device = args.device

    # load model, tokenizer, config
    logger.info('-> Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    # ``embedding_gradient.get()`` is read after each backward pass below.
    # It returns a 3-D tensor unpacked as (bsz, seq, emb_dim), presumably the
    # gradient w.r.t. the word-embedding output; TODO confirm in utils.OutputStorage.
    embedding_gradient = utils.OutputStorage(model, config)
    embeddings = embedding_gradient.embeddings
    predictor = ModelWrapper(model, tokenizer)

    if args.prompt:
        prompt_ids = list(args.prompt)
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if len(prompt_ids) != tokenizer.num_prompt_tokens:
            raise ValueError(
                f"args.prompt has {len(prompt_ids)} token ids, expected "
                f"tokenizer.num_prompt_tokens={tokenizer.num_prompt_tokens}"
            )
    else:
        # Random initial prompt made of distinct vocabulary ids.
        prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
    # Shape (1, num_prompt_tokens); column ``fidx`` is the token being flipped.
    prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    # load dataset & evaluation function
    evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)

    # saving results: best dev metrics / prompt plus a string snapshot of every arg
    best_results = {
        "acc": -float('inf'),
        "F1Score": -float('inf'),
        "best_prompt_ids": None,
        "best_prompt_token": None,
    }
    for k, v in vars(args).items():
        best_results[str(k)] = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
    torch.save(best_results, args.output)

    train_iter = iter(train_loader)
    pharx = tqdm(range(args.iters))
    for iters in pharx:
        start = time.time()
        model.zero_grad()
        averaged_grad = None

        # for prompt optimization: average the gradient at the prompt positions
        phar = tqdm(range(args.accumulation_steps))
        for step in phar:
            model_inputs, train_iter = _next_batch(train_iter, train_loader)
            c_labels = model_inputs["labels"].to(device)
            c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
            loss.backward()
            c_grad = embedding_gradient.get()
            bsz, _, emb_dim = c_grad.size()
            # Keep only the rows flagged by prompt_mask. The view requires
            # exactly bsz * num_prompt_tokens selected rows, i.e. every example
            # marks num_prompt_tokens positions (assumed; view fails otherwise).
            selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
            cp_grad = torch.masked_select(c_grad, selection_mask)
            cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
            # accumulate gradient: summed over the batch, averaged over steps
            if averaged_grad is None:
                averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps
            del model_inputs
            phar.set_description(f'-> Accumulate grad: [{iters+1}/{args.iters}] [{step+1}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}')

        # Try flipping up to two distinct, randomly chosen prompt positions.
        size = min(tokenizer.num_prompt_tokens, 2)
        prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
        for fidx in prompt_flip_idx:
            # Candidate replacement ids for position ``fidx`` from the averaged gradient.
            prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight, increase_loss=False,
                                                     num_candidates=args.num_cand, filter=None)
            # select best prompt: score the current prompt and every candidate
            # on the same fresh training batches
            prompt_denom, prompt_current_score = 0, 0
            prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
            phar = tqdm(range(args.accumulation_steps))
            for step in phar:
                model_inputs, train_iter = _next_batch(train_iter, train_loader)
                c_labels = model_inputs["labels"].to(device)
                with torch.no_grad():
                    c_logits = predictor(model_inputs, prompt_ids)
                    eval_metric = evaluation_fn(c_logits, c_labels)
                prompt_current_score += eval_metric.sum()
                prompt_denom += c_labels.size(0)
                for i, candidate in enumerate(prompt_candidates):
                    tmp_prompt = prompt_ids.clone()
                    tmp_prompt[:, fidx] = candidate
                    with torch.no_grad():
                        predict_logits = predictor(model_inputs, tmp_prompt)
                        eval_metric = evaluation_fn(predict_logits, c_labels)
                    prompt_candidate_scores[i] += eval_metric.sum()
                del model_inputs
            if (prompt_candidate_scores > prompt_current_score).any():
                best_candidate_score = prompt_candidate_scores.max()
                best_candidate_idx = prompt_candidate_scores.argmax()
                prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx]
                print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
                print(f"-> Current Best prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
        del averaged_grad

        # Evaluation for clean samples
        clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids)
        if clean_metric[metric_key] > best_results[metric_key]:
            prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
            best_results["best_prompt_ids"] = prompt_ids.tolist()
            best_results["best_prompt_token"] = prompt_token
            best_results.update(clean_metric)
            # NOTE(review): prints clean_metric["acc"] even when metric_key is
            # "F1Score"; confirm Evaluation.evaluate always returns an "acc" entry.
            print(f'-> [{iters+1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n')

        # print results
        print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}')
        print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n')

        # save results (checkpoint after every iteration)
        cost_time = time.time() - start
        pharx.set_description(f"-> [{iters+1}/{args.iters}] cost: {cost_time}s save results: {best_results}")
        best_results["curr_iters"] = iters
        # Timezone-aware UTC; same string as the deprecated datetime.utcnow().
        best_results["curr_times"] = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
        best_results["curr_cost"] = int(cost_time)
        torch.save(best_results, args.output)
if __name__ == '__main__':
    # Script entry point: parse CLI options, configure root logging, run the search.
    from .augments import get_args

    args = get_args()
    # ``--debug`` switches logging from INFO to DEBUG verbosity.
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level)
    run_model(args)