Spaces:
Running
on
Zero
Running
on
Zero
"""
Evaluation utilities for PyTorch-based model-rewriting methods.
To use, call `compute_rewrite_quality_zsre` with the appropriate arguments;
it returns a dictionary of rewrite-quality metrics.
Adapted from the MEMIT / ROME reference implementation.
MIT License
Copyright (c) 2022 Kevin Meng
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import typing | |
from itertools import chain | |
import numpy as np | |
import torch | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
def compute_rewrite_quality_zsre(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
    vec: TfidfVectorizer,
) -> typing.Dict:
    """
    Compute zsRE-style rewrite metrics for a (rewritten) model.

    Rewrite and paraphrase prompts are scored token by token against the new
    target with teacher forcing. For every prefix length ``i``, the prompt is
    extended with the first ``i`` target tokens. The model's greedy next-token
    prediction is then compared with target token ``i``. Each neighborhood
    prompt is scored with a single next-token prediction against its stored
    target.

    :param model: Rewritten model to evaluate
    :param tok: Tokenizer matching ``model``
    :param record: zsRE evaluation record. Keys read here:
        ``requested_rewrite``, which contains:
            ``subject``;
            ``prompt``, a format string that receives the subject;
            ``target_new["str"]``.
        ``paraphrase_prompts``, a list of str.
        ``neighborhood_prompts``, a list of dicts with ``prompt`` and ``target``.
    :param vec: Unused; kept so existing call sites keep working.
    :return: Dictionary with three keys:
        ``rewrite_prompts_correct`` and ``paraphrase_prompts_correct``: one bool
        per (prompt, target token) pair, in prompt-major order;
        ``neighborhood_prompts_correct``: one bool per neighborhood prompt.
    """
    # Unpack the rewrite evaluation record.
    rewrite = record["requested_rewrite"]
    subject = rewrite["subject"]
    target_new = rewrite["target_new"]

    rewrite_prompts = [rewrite["prompt"].format(subject)]
    paraphrase_prompts = record["paraphrase_prompts"]
    neighborhood_prompts = record["neighborhood_prompts"]

    # The target follows the prompt text, hence the leading space.
    target_tok = tok(" " + target_new["str"], add_special_tokens=False)["input_ids"]
    n_tok = len(target_tok)

    # Teacher forcing: "prompt + first i target tokens" should predict token i.
    scored_prompts = list(chain(rewrite_prompts, paraphrase_prompts))
    inp_prompts = [
        prompt + tok.decode(target_tok[:i])
        for prompt in scored_prompts
        for i in range(n_tok)
    ]
    inp_targets = [
        tok.decode(target_tok[i])
        for _ in scored_prompts
        for i in range(n_tok)
    ]
    stuff_probs = (
        test_batch_prediction_acc(model, tok, inp_prompts, inp_targets)
        if inp_prompts
        else []
    )

    # Neighborhood prompts: one next-token check each. Format with the
    # subject (not the whole rewrite dict), mirroring the rewrite prompt.
    neighborhood_correct = (
        test_batch_prediction_acc(
            model,
            tok,
            [el["prompt"].format(subject) for el in neighborhood_prompts],
            [el["target"] for el in neighborhood_prompts],
        )
        if neighborhood_prompts
        else []
    )

    # Split the flat teacher-forced results back into their prompt groups.
    n_rewrite = len(rewrite_prompts) * n_tok
    n_paraphrase = len(paraphrase_prompts) * n_tok

    # Structure the results as a dictionary.
    return {
        "rewrite_prompts_correct": stuff_probs[:n_rewrite],
        "paraphrase_prompts_correct": stuff_probs[n_rewrite : n_rewrite + n_paraphrase],
        "neighborhood_prompts_correct": neighborhood_correct,
    }
def test_batch_prediction_acc(model, tok, prompts: typing.List[str], target):
    """
    Check whether the model's greedy next-token prediction matches a target.

    All prompts are run through ``model`` as one padded batch. The logits at
    each prompt's last attended position are arg-maxed and compared with the
    first token id of the corresponding target string.

    :param model: Causal LM called as ``model(**batch)``; its output must
        expose ``.logits``. Inputs are placed on ``model.device``, or on the
        device of its first parameter if it has no ``device`` attribute.
    :param tok: Tokenizer for ``model``; it must be able to pad ``prompts``.
    :param prompts: Prompts whose next token is predicted.
    :param target: One target string per prompt. Alternatively, a single
        string (or one-element list) that is compared against every prompt.
    :return: List of bools, one per prompt; ``True`` when the prediction
        matches the target's first token.
    """
    if not prompts:
        # Nothing to score.
        return []

    # Run on the model's own device instead of assuming CUDA.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    prompt_tok = tok(prompts, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**prompt_tok).logits

    # Find the last attended token in each row: the largest position whose
    # mask value is 1. This is correct for both right- and left-padded batches
    # (e.g. Gemma). ``mask.sum(1) - 1`` is only correct for right padding.
    mask = prompt_tok["attention_mask"]
    positions = torch.arange(mask.shape[1], device=mask.device)
    last_idx = (mask.long() * positions).amax(dim=1).to(logits.device)
    rows = torch.arange(logits.shape[0], device=logits.device)
    ans = logits[rows, last_idx].argmax(dim=-1)

    # A target string may tokenize to several ids (e.g. foreign characters
    # that do not round-trip); only the first id is compared. Targets are
    # tokenized without padding so that a left-padding tokenizer cannot put a
    # pad id in front of shorter targets. Empty targets get -1, which never
    # matches a predicted id.
    targets = [target] if isinstance(target, str) else list(target)
    target_ids = tok(targets, add_special_tokens=False)["input_ids"]
    correct_id = torch.tensor(
        [ids[0] if ids else -1 for ids in target_ids], device=ans.device
    )

    # A one-element ``correct_id`` broadcasts against every prediction.
    return (ans == correct_id).cpu().tolist()


# The name starts with "test_"; stop pytest from collecting this helper as a
# test when it is imported into a test module.
test_batch_prediction_acc.__test__ = False