Spaces:
Runtime error
Runtime error
import textattack | |
import transformers | |
import pandas as pd | |
import csv | |
import string | |
import pickle | |
# Construct our four components for `Attack` | |
from textattack.constraints.pre_transformation import ( | |
RepeatModification, | |
StopwordModification, | |
) | |
from textattack.constraints.semantics import WordEmbeddingDistance | |
from textattack.transformations import WordSwapEmbedding | |
from textattack.search_methods import GreedyWordSwapWIR | |
import numpy as np | |
import json | |
import random | |
import re | |
import textattack.shared.attacked_text as atk | |
import torch.nn.functional as F | |
import torch | |
class InvertedText:
    """A candidate text whose predicted class differs from the original one."""

    def __init__(
        self,
        swapped_indexes,
        score,
        attacked_text,
        new_class,
    ):
        """Record one inverted candidate.

        Args:
            swapped_indexes: dict mapping a word index to the synonym that was
                substituted at that index.
            score: probability of ``new_class`` for this candidate (this is what
                ``Flow_Corrector.correct`` passes in).
            attacked_text: the candidate text itself.
            new_class: the class predicted for the candidate.
        """
        self.swapped_indexes = swapped_indexes
        self.score = score
        self.attacked_text = attacked_text
        self.new_class = new_class

    def __repr__(self):
        lines = (
            "InvertedText:",
            f" attacked_text='{self.attacked_text}', ",
            f" swapped_indexes={self.swapped_indexes},",
            f" score={self.score}",
        )
        return "\n".join(lines)
def count_matching_classes(original, corrected, perturbed_texts=None):
    """Count positions where the corrected class equals the original class.

    Args:
        original: reference class labels.
        corrected: predicted class labels, same length as ``original``.
        perturbed_texts: optional texts aligned with the labels. When given, each
            text is put into ``easy_samples`` (labels match) or ``hard_samples``
            (labels differ). When None, only the count is computed.

    Returns:
        ``(matching_count, hard_samples, easy_samples)``. Both sample lists are
        empty when ``perturbed_texts`` is None.

    Raises:
        ValueError: if ``original`` and ``corrected`` differ in length.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    hard_samples = []
    easy_samples = []
    matching_count = 0
    for i, (orig_cls, corr_cls) in enumerate(zip(original, corrected)):
        matched = orig_cls == corr_cls
        if matched:
            matching_count += 1
        # Only sort texts into easy/hard buckets when texts were supplied;
        # indexing None would raise TypeError.
        if perturbed_texts is None:
            continue
        (easy_samples if matched else hard_samples).append(perturbed_texts[i])
    return matching_count, hard_samples, easy_samples
class Flow_Corrector:
    """Re-predict the class of a possibly adversarial text by synonym voting.

    For each input text the corrector:

    1. ranks words by gradient importance;
    2. keeps the top ``wir_threshold`` fraction;
    3. replaces each kept word with a random synonym (proposed by the attack's
       transformation) that ranks as more frequent than the original word;
    4. lets the victim model vote over the resulting candidates.
    """

    def __init__(
        self,
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
        wir_threshold=0.3,
    ):
        """Store the attack, call its ``cuda_()`` and load the frequency tables.

        Args:
            attack: a TextAttack ``Attack``. Its transformation/constraints
                propose synonyms; ``attack.goal_function.model`` is the victim
                model wrapper.
            word_rank_file: JSON mapping word -> frequency rank. The code treats
                smaller ranks as more frequent words.
            word_freq_file: JSON word-frequency table. It is loaded into
                ``self.word_frequence`` but not used by this class.
            wir_threshold: fraction of the most important word indexes to perturb.
        """
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r", encoding="utf-8") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r", encoding="utf-8") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(
        self,
        attack,
        victim_model,
        detected_text,
    ):
        """Return the attack's modifiable word indexes, most important first.

        A word's importance is the L1 norm of the mean gradient over the model
        tokens aligned with it. Words with no aligned token score 0.
        """
        _, indices_to_order = attack.get_indices_to_order(detected_text)
        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if matched_tokens:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        return np.array(indices_to_order)[(-index_scores).argsort()]

    def get_syn_freq_dict(
        self,
        index_order,
        detected_text,
    ):
        """Map each word index to its synonyms that are more frequent than the word.

        A synonym proposed by ``self.attack.get_transformations`` is kept when
        all of the following hold:

        - it is in the rank table;
        - its rank is below 10% of the table size;
        - its rank is smaller than the original word's rank.

        An index is left out of the result when its word is missing from the
        rank table, when no synonym qualifies, or when generating its synonyms
        fails.

        Returns:
            dict mapping index -> list of qualifying synonyms, in proposal order
            and without duplicates.
        """
        ranks = self.word_ranked_frequence
        freq_threshold = len(ranks) / 10
        most_frequent_syn_dict = {}
        for idx in index_order:
            try:
                original_rank = ranks.get(detected_text.words[idx])
                if original_rank is None:
                    # An unranked word cannot be compared against, so no synonym
                    # can qualify; skip the costly transformation call.
                    continue
                synonyms = [
                    text.words[idx]
                    for text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                qualifying = [
                    syn
                    for syn in synonyms
                    if syn in ranks
                    and ranks[syn] < freq_threshold
                    and original_rank > ranks[syn]
                ]
            except Exception:
                # Best effort: skip this index if its synonyms cannot be built.
                # (Catching Exception, not everything, keeps Ctrl-C working.)
                continue
            if qualifying:
                most_frequent_syn_dict[idx] = list(dict.fromkeys(qualifying))
        return most_frequent_syn_dict

    def build_candidates(
        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int
    ):
        """Draw up to ``max_attempt`` random synonym substitutions of ``detected_text``.

        Each attempt replaces every index in ``most_frequent_syn_dict`` with one of
        its synonyms, chosen by ``random.choice``.

        Returns:
            dict mapping each candidate text to the ``{index: synonym}`` choices
            that produced it. Equal candidate texts share one key (the later
            attempt overwrites the earlier), so fewer than ``max_attempt``
            entries may be returned.
        """
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index, synonyms in most_frequent_syn_dict.items():
                syn = random.choice(synonyms)
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)
            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        """Return the most common ``new_class`` among ``inverted_texts``.

        Ties go to the class seen first. Raises ValueError on an empty list.
        """
        class_counts = {}
        for text in inverted_texts:
            class_counts[text.new_class] = class_counts.get(text.new_class, 0) + 1
        return max(class_counts, key=class_counts.get)

    def _select_class(self, inverted_texts, num_candidates, num_classes, original_class):
        """Decide between ``original_class`` and the dominant flipped class.

        The prediction flips when at least ``num_candidates / num_classes``
        candidates changed class (``num_candidates / 2`` for binary tasks). The
        new label is then the most common class among the flipped candidates.
        """
        if not inverted_texts:
            return original_class
        threshold = num_candidates / max(num_classes, 2)
        if len(inverted_texts) >= threshold:
            return self.find_dominant_class(inverted_texts)
        return original_class

    def correct(self, detected_texts):
        """Return one predicted class per raw text in ``detected_texts``, in order."""
        corrected_classes = []
        for raw_text in detected_texts:
            detected_text = atk.AttackedText(raw_text)
            # Perturb only the `wir_threshold` fraction of most important words.
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            # The wrapper is called with a list of strings; row 0 holds this
            # text's class scores.
            original_probs = F.softmax(self.victim_model([detected_text.text]), dim=1)
            original_class = torch.argmax(original_probs[0]).item()
            num_classes = original_probs.shape[1]

            inverted_texts = []
            if candidates:
                batch_outputs = self.victim_model([c.text for c in candidates])
                probabilities = F.softmax(batch_outputs, dim=1)
                for (candidate, syn_dict), probs in zip(candidates.items(), probabilities):
                    new_class = torch.argmax(probs).item()
                    if new_class != original_class:
                        inverted_texts.append(
                            InvertedText(
                                syn_dict, float(probs[new_class]), candidate, new_class
                            )
                        )
            corrected_classes.append(
                self._select_class(
                    inverted_texts, len(candidates), num_classes, original_class
                )
            )
        return corrected_classes
def remove_brackets(text):
    """Delete every ``[[`` and then every ``]]`` from ``text``.

    These are presumably the markers TextAttack writes around changed words in
    its CSV output.
    """
    for marker in ("[[", "]]"):
        text = text.replace(marker, "")
    return text
def clean_text(text):
    """Replace each character of ``string.punctuation`` in ``text`` with a space."""
    punctuation_to_space = str.maketrans(
        string.punctuation, " " * len(string.punctuation)
    )
    return text.translate(punctuation_to_space)
# Victim model: the "textattack/bert-base-uncased-ag-news" checkpoint, wrapped
# so TextAttack can query it.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
    model=model, tokenizer=tokenizer
)

# Attack components: goal, constraints, transformation and search strategy.
goal_function = textattack.goal_functions.UntargetedClassification(
    model_wrapper=model_wrapper
)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

attack = textattack.Attack(
    goal_function=goal_function,
    constraints=constraints,
    transformation=transformation,
    search_method=search_method,
)
attack.cuda_()
# Keep only rows where the attack succeeded. Strip the [[ ]] markers, then turn
# punctuation into spaces, for both the perturbed and the clean text.
results = pd.read_csv("ag_news_results.csv")
perturbed_texts = [
    clean_text(remove_brackets(text))
    for text in results.loc[results["result_type"] == "Successful", "perturbed_text"]
]
original_texts = [
    clean_text(remove_brackets(text))
    for text in results.loc[results["result_type"] == "Successful", "original_text"]
]
victim_model = attack.goal_function.model
print("Getting corrected classes")
print("This may take a while ...")
# Reference labels: the victim model's prediction on each clean text (the
# result CSV could be used directly instead).
# The wrapper takes a list of strings. Texts go through in chunks of 32 rather
# than one forward pass per text. Softmax is skipped because it does not
# change the argmax.
original_classes = [
    predicted
    for start in range(0, len(original_texts), 32)
    for predicted in torch.argmax(
        victim_model(original_texts[start:start + 32]), dim=1
    ).tolist()
]
# Split the aligned lists into consecutive chunks of `batch_size` items; the
# last chunk may be shorter.
batch_size = 1000
num_batches = -(-len(perturbed_texts) // batch_size)  # ceiling division
batched_perturbed_texts = [
    perturbed_texts[k * batch_size:(k + 1) * batch_size] for k in range(num_batches)
]
batched_original_texts = [
    original_texts[k * batch_size:(k + 1) * batch_size] for k in range(num_batches)
]
batched_original_classes = [
    original_classes[k * batch_size:(k + 1) * batch_size] for k in range(num_batches)
]
print(batched_original_classes)
hard_samples_list = []
easy_samples_list = []
# One CSV row per (batch, WIR threshold). Each row holds the fraction of texts
# whose corrected class matches the victim model's class on the clean text.
csv_filename = "flow_correction_results_ag_news.csv"
with open(csv_filename, "w", newline="", encoding="utf-8") as csvfile:
    # Note: the "freq_threshold" column stores the WIR threshold.
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    # Build the corrector once so the two vocabulary JSON files are loaded a
    # single time; only its threshold changes between runs.
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )
    for batch_num, (perturbed, original, classes) in enumerate(
        zip(batched_perturbed_texts, batched_original_texts, batched_original_classes),
        start=1,
    ):
        print(f"Processing batch number: {batch_num}")
        for wir_threshold in (0.1, 0.2):
            print(f"Setting Word threshold to: {wir_threshold}")
            corrector.wir_threshold = wir_threshold
            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)
            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            # Samples are collected once per threshold, so a text may appear
            # more than once in these lists.
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)
            print(f"Number of matching classes (perturbed): {match_perturbed}")
            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, _, _ = count_matching_classes(
                classes, corrected_original_classes, original
            )
            print(f"Number of matching classes (original): {match_original}")
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed / len(perturbed),
                    "match_original": match_original / len(original),
                }
            )
            # Flush so completed rows survive a crash in a later batch.
            csvfile.flush()
            print("-" * 20)
print("savig samples for more statistics studies")
# Save hard_samples_list and easy_samples_list to files
with open("hard_samples.pkl", "wb") as f:
    pickle.dump(hard_samples_list, f)
with open("easy_samples.pkl", "wb") as f:
    pickle.dump(easy_samples_list, f)