""" | |
Contains evaluation utilities for pytorch-based rewriting methods. | |
To use, simply call `compute_rewrite_quality_counterfact` with the | |
appropriate arguments, which returns a dictionary containing them. | |
Script from memit ROME implementation | |
MIT License | |
Copyright (c) 2022 Kevin Meng | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
""" | |

import typing
from itertools import chain

import nltk
import numpy as np
import scipy.stats
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from util.generate import generate_fast


def perplexity(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    text: str,
    max_input_length: typing.Optional[int] = None,
):
    """
    Computes perplexity of a piece of text, measured on a reference model.
    Text is truncated to max_input_length tokens.
    """
    inputs = tok(
        [text], return_tensors="pt", max_length=max_input_length, truncation=True
    ).to("cuda")
    logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)
    log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]

    # Perplexity = exp(-1/N * log P(x_1, ..., x_n))
    return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item()
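
# A minimal usage sketch (model name and text are illustrative; the helpers in
# this file assume a CUDA device is available):
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   ppl = perplexity(model, tok, "Paris is the capital of France.",
#                    max_input_length=100)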


def compute_rewrite_quality_counterfact(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
    vec: TfidfVectorizer,
) -> typing.Dict:
    """
    Given a rewritten model, computes generalization and specificity metrics for
    the desired rewrite (passed in via the CounterFact dataset record). Returns a
    dictionary containing those metrics.

    :param model: Rewritten model
    :param tok: Tokenizer
    :param record: CounterFact dataset record
    :param vec: TF-IDF vectorizer, used for generation consistency scoring
        (see `test_generation`)
    :return: Dictionary containing rewriting metrics
    """
    # First, unpack rewrite evaluation record.
    subject, target_new, target_true = (
        record["requested_rewrite"][x] for x in ["subject", "target_new", "target_true"]
    )
    rewrite_prompts = [record["requested_rewrite"]["prompt"].format(subject)]
    paraphrase_prompts = record["paraphrase_prompts"]
    neighborhood_prompts = record["neighborhood_prompts"]
    generation_prompts = record["generation_prompts"]

    # Form a list of lists of prefixes to test.
    prob_prompts = [
        rewrite_prompts,
        paraphrase_prompts,
        neighborhood_prompts,
    ]
    # Which target counts as "correct" for each prefix: 0 = target_new (the
    # rewrite should generalize to rewrite and paraphrase prompts), 1 =
    # target_true (neighborhood prompts should be unaffected by the rewrite).
    which_correct = [
        [0 for _ in range(len(rewrite_prompts))],
        [0 for _ in range(len(paraphrase_prompts))],
        [1 for _ in range(len(neighborhood_prompts))],
    ]

    # Flatten all the evaluated prefixes into one list.
    probs, targets_correct = test_batch_prediction(
        model,
        tok,
        list(chain(*prob_prompts)),
        list(chain(*which_correct)),
        target_new["str"],
        target_true["str"],
    )

    # Unflatten the results again into a list of lists.
    cutoffs = [0] + np.cumsum(list(map(len, prob_prompts))).tolist()
    ret_probs = [probs[cutoffs[i - 1] : cutoffs[i]] for i in range(1, len(cutoffs))]
    ret_corrects = [
        targets_correct[cutoffs[i - 1] : cutoffs[i]] for i in range(1, len(cutoffs))
    ]
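    # For example (illustrative sizes): with 1 rewrite prompt, 2 paraphrases,
    # and 10 neighborhood prompts, cutoffs == [0, 1, 3, 13], so the flat
    # results are sliced back into lists of lengths 1, 2, and 10.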

    # Structure the results as a dictionary (dict union, Python 3.9+).
    ret = {
        f"{key}_probs": ret_probs[i]
        for i, key in enumerate(
            [
                "rewrite_prompts",
                "paraphrase_prompts",
                "neighborhood_prompts",
            ]
        )
    } | {
        f"{key}_correct": ret_corrects[i]
        for i, key in enumerate(
            [
                "rewrite_prompts",
                "paraphrase_prompts",
                "neighborhood_prompts",
            ]
        )
    }

    return ret
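
# An illustrative CounterFact record shape (abbreviated; field names follow the
# dataset, values are invented):
#
#   record = {
#       "requested_rewrite": {
#           "prompt": "{} is the capital of",
#           "subject": "Paris",
#           "target_new": {"str": "Italy"},
#           "target_true": {"str": "France"},
#       },
#       "paraphrase_prompts": ["The city of Paris is the capital of"],
#       "neighborhood_prompts": ["Lyon is located in the country of"],
#       "generation_prompts": ["Paris is famous for"],
#   }
#
# The returned dictionary then holds "rewrite_prompts_probs",
# "paraphrase_prompts_probs", "neighborhood_prompts_probs", and the matching
# "*_correct" lists.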


def test_batch_prediction(
    model,
    tok,
    prefixes: typing.List[str],
    which_correct: typing.List[int],
    target_new: str,
    target_true: str,
):
    """
    which_correct: Which target to consider correct for each prefix.
    Either 0 for "new" (target_new) or 1 for "true" (target_true).
    """
    # Measure prefix lengths without special tokens; the leading padding/BOS
    # that some tokenizers add to the batched encoding is compensated for
    # per-row below (see `additional`).
    prefix_lens = [len(n) for n in tok(prefixes, add_special_tokens=False)["input_ids"]]
    prompt_tok = tok(
        [
            f"{prefix} {suffix}"
            for prefix in prefixes
            for suffix in [target_new, target_true]
        ],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    a_tok, b_tok = (
        tok(f" {n}", add_special_tokens=False)["input_ids"]
        for n in [target_new, target_true]
    )
    choice_a_len, choice_b_len = (len(n) for n in [a_tok, b_tok])

    with torch.no_grad():
        logits = model(**prompt_tok).logits

    probs = np.zeros((logits.size(0),), dtype=np.float32)
    targets_correct = []

    for i in range(logits.size(0)):
        cur_len = choice_a_len if i % 2 == 0 else choice_b_len

        # Offset for tokenizers (e.g. Gemma's) that pad on the left instead of
        # the right: count the leading padding tokens, plus one more for the
        # BOS token inserted after them.
        num_pads = int(torch.where(prompt_tok["attention_mask"][i] == 1)[0][0].item())
        additional = num_pads + 1 if num_pads != 0 else 0

        # Compute suffix probabilities.
        for j in range(cur_len):
            cur_tok = (a_tok if i % 2 == 0 else b_tok)[j]
            probs[i] += -torch.nn.functional.log_softmax(
                logits[i, additional + prefix_lens[i // 2] + j - 1, :], dim=0
            )[cur_tok].item()
        probs[i] /= cur_len

        # Check whether the designated "correct" target is greedily decoded
        # at every suffix position.
        if (which_correct[i // 2] == 0 and i % 2 == 0) or (
            which_correct[i // 2] == 1 and i % 2 == 1
        ):
            correct = True
            for j in range(cur_len):
                cur_tok = (a_tok if i % 2 == 0 else b_tok)[j]
                if (
                    logits[i, additional + prefix_lens[i // 2] + j - 1, :]
                    .argmax()
                    .item()
                    != cur_tok
                ):
                    correct = False
                    break
            targets_correct.append(correct)

    return [
        {"target_new": probs[i].item(), "target_true": probs[i + 1].item()}
        for i in range(0, len(probs), 2)
    ], targets_correct
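
# Each entry of the first return value pairs the average per-token negative
# log-likelihoods of the two targets, so lower means more likely; e.g. a
# (hypothetical) entry {"target_new": 1.2, "target_true": 4.7} indicates the
# model now prefers the rewritten target for that prefix.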


def test_generation(
    model,
    tok,
    prefixes: typing.List[str],
    consistency_texts: typing.List[str],
    essence_texts: typing.List[str],
    vec: TfidfVectorizer,
):
    """
    Evaluates generation quality: n-gram entropy of the generations (fluency),
    TF-IDF similarity to reference texts (consistency), and, when essence
    texts are given, their perplexity under the model (essence).
    """
    gen_texts = generate_fast(
        model,
        tok,
        prefixes,
        n_gen_per_prompt=1,
        max_out_len=100,
    )

    ngram_entropy = n_gram_entropy(gen_texts)
    consistency_tfidf = tfidf_similarity(
        " ".join(gen_texts), " ".join(consistency_texts), vec
    )

    ret = {
        "ngram_entropy": ngram_entropy,
        "reference_score": consistency_tfidf,
        "text": gen_texts,
    }

    if len(essence_texts) > 0:
        ppl = perplexity(model, tok, " ".join(essence_texts), max_input_length=100)
        ret.update({"essence_score": ppl, "essence_text": essence_texts})
    return ret
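
# A usage sketch (texts are illustrative; how `vec` is fitted depends on the
# surrounding pipeline):
#
#   metrics = test_generation(
#       model, tok,
#       prefixes=record["generation_prompts"],
#       consistency_texts=["Reference text about the rewritten fact."],
#       essence_texts=["Reference text about the original subject."],
#       vec=fitted_tfidf_vectorizer,  # hypothetical pre-fitted TfidfVectorizer
#   )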


def n_gram_entropy(gen_texts, agg="arith"):
    assert agg in ["arith", "geom"]
    return (scipy.stats.mstats.gmean if agg == "geom" else np.mean)(
        [compute_n_gram_entropy(txt) for txt in gen_texts]
    ).item()


def compute_n_gram_entropy(sentence, ns=None, weights=None, agg="arith"):
    if ns is None:
        ns = [2, 3]
    if weights is None:
        weights = [2 / 3, 4 / 3]
    assert agg in ["arith", "geom"]

    entropy_list = []
    for n in ns:
        fdist = compute_freq(sentence, n)
        freqs = np.array([freq for _, freq in fdist.items()])
        freqs = freqs / freqs.sum()
        # Shannon entropy in bits of the n-gram distribution.
        entropy_list.append(np.sum(-freqs * np.log(freqs) / np.log(2)))

    entropy_list = np.array(entropy_list) * np.array(weights)
    return (scipy.stats.mstats.gmean if agg == "geom" else np.mean)(entropy_list)
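
# Worked example (illustrative): for "the cat sat on the mat", the 5 bigrams
# are all distinct, so each has frequency 1/5 and the bigram entropy is
# -5 * (1/5) * log2(1/5) = log2(5) ≈ 2.32 bits; repeated n-grams lower the
# entropy, which is why degenerate, repetitive generations score poorly here.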


def compute_freq(sentence, n=2):
    # Requires NLTK's punkt tokenizer data (nltk.download("punkt")).
    tokens = nltk.word_tokenize(sentence)
    ngrams = nltk.ngrams(tokens, n)
    return nltk.FreqDist(ngrams)


def tfidf_similarity(text_a, text_b, vec):
    # Cosine similarity between the TF-IDF encodings of the two texts.
    encs = vec.transform([text_a, text_b]).toarray()
    norm = np.linalg.norm
    return (np.dot(encs[0], encs[1]) / norm(encs[0]) / norm(encs[1])).item()
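
# A usage sketch (corpus and texts are illustrative; the real pipeline passes
# in a pre-fitted vectorizer):
#
#   vec = TfidfVectorizer().fit(
#       ["Paris is the capital of France.", "Rome is the capital of Italy."]
#   )
#   sim = tfidf_similarity("Paris is in France", "Paris, France's capital", vec)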