anonymous8
update
d65ddc0
"""
Attacker Class
==============
"""
import collections
import logging
import multiprocessing as mp
import os
import queue
import random
import traceback
import torch
import tqdm
import textattack
from textattack.attack_results import (
FailedAttackResult,
MaximizedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.shared.utils import logger
from .attack import Attack
from .attack_args import AttackArgs
class Attacker:
    """Class for running attacks on a dataset with specified parameters. This
    class uses the :class:`~textattack.Attack` to actually run the attacks,
    while also providing useful features such as parallel processing,
    saving/resuming from a checkpoint, logging to files and stdout.

    Args:
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to actually carry out the attack.
        dataset (:class:`~textattack.datasets.Dataset`):
            Dataset to attack.
        attack_args (:class:`~textattack.AttackArgs`):
            Arguments for attacking the dataset. For default settings, look at the `AttackArgs` class.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
        >>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Attack 20 samples with CSV logging and checkpoint saved every 5 interval
        >>> attack_args = textattack.AttackArgs(
        ...     num_examples=20,
        ...     log_to_csv="log.csv",
        ...     checkpoint_interval=5,
        ...     checkpoint_dir="checkpoints",
        ...     disable_stdout=True
        ... )

        >>> attacker = textattack.Attacker(attack, dataset, attack_args)
        >>> attacker.attack_dataset()
    """

    def __init__(self, attack, dataset, attack_args=None):
        assert isinstance(
            attack, Attack
        ), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), f"`dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(dataset)}`."

        if attack_args is None:
            attack_args = AttackArgs()
        else:
            assert isinstance(
                attack_args, AttackArgs
            ), f"`attack_args` must be of type `textattack.AttackArgs`, but got type `{type(attack_args)}`."

        self.attack = attack
        self.dataset = dataset
        self.attack_args = attack_args
        # Created in `attack_dataset` unless already restored from a checkpoint.
        self.attack_log_manager = None

        # This is to be set if loading from a checkpoint
        self._checkpoint = None

    def _get_worklist(self, start, end, num_examples, shuffle):
        """Split dataset indices ``[start, end)`` into a worklist and spare candidates.

        Args:
            start (int): First dataset index to consider.
            end (int): One past the last dataset index to consider.
            num_examples (int): Number of indices placed in the initial worklist.
            shuffle (bool): Shuffle the indices before splitting.

        Returns:
            tuple[collections.deque, collections.deque]: ``(worklist, candidates)``.
            ``candidates`` holds the remaining indices, which are used to replace
            examples whose results get discarded.
        """
        if end - start < num_examples:
            logger.warning(
                f"Attempting to attack {num_examples} samples when only {end-start} are available."
            )
        candidates = list(range(start, end))
        if shuffle:
            random.shuffle(candidates)
        worklist = collections.deque(candidates[:num_examples])
        candidates = collections.deque(candidates[num_examples:])
        assert (len(worklist) + len(candidates)) == (end - start)
        return worklist, candidates

    def _init_worklist(self):
        """Return ``(num_remaining_attacks, worklist, worklist_candidates)``.

        Restores this state from the loaded checkpoint if there is one.
        Otherwise it builds a fresh worklist sized by ``num_successful_examples``
        when that is set, or by ``num_examples`` when it is not.
        """
        if self._checkpoint:
            state = (
                self._checkpoint.num_remaining_attacks,
                self._checkpoint.worklist,
                self._checkpoint.worklist_candidates,
            )
            logger.info(
                f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
            )
            return state

        if self.attack_args.num_successful_examples:
            target = self.attack_args.num_successful_examples
        else:
            target = self.attack_args.num_examples
        # We make `worklist` deque (linked-list) for easy pop and append.
        # Candidates are other samples we can attack if we need more samples.
        worklist, worklist_candidates = self._get_worklist(
            self.attack_args.num_examples_offset,
            len(self.dataset),
            target,
            self.attack_args.shuffle,
        )
        return target, worklist, worklist_candidates

    def _init_counts(self):
        """Return the running result tallies.

        The values are restored from the loaded checkpoint if there is one,
        and are all zero otherwise.
        """
        if self._checkpoint:
            return {
                "results": self._checkpoint.results_count,
                "failures": self._checkpoint.num_failed_attacks,
                "skipped": self._checkpoint.num_skipped_attacks,
                "successes": self._checkpoint.num_successful_attacks,
            }
        return {"results": 0, "failures": 0, "skipped": 0, "successes": 0}

    def _to_attacked_text(self, text):
        """Wrap raw input ``text`` as ``AttackedText``.

        The dataset's label names are attached when it has them.
        """
        example = textattack.shared.AttackedText(text)
        if self.dataset.label_names is not None:
            example.attack_attrs["label_names"] = self.dataset.label_names
        return example

    def _should_discard(self, result):
        """Return True when ``result`` does not count toward the target and must be replaced.

        That happens for a skipped result when ``attack_n`` is set, or for any
        non-successful result when ``num_successful_examples`` is set.
        """
        return bool(
            (isinstance(result, SkippedAttackResult) and self.attack_args.attack_n)
            or (
                not isinstance(result, SuccessfulAttackResult)
                and self.attack_args.num_successful_examples
            )
        )

    def _record_result(self, result, counts, pbar, print_separator=False):
        """Log an accepted ``result`` and update the tallies and the progress bar.

        Args:
            result: Attack result that counts toward the requested number of examples.
            counts (dict): Running tallies from ``_init_counts``; updated in place.
            pbar: ``tqdm`` progress bar to advance and relabel.
            print_separator (bool): Print a blank separator to stdout right after
                the result is logged.
        """
        pbar.update(1)
        self.attack_log_manager.log_result(result)
        if print_separator:
            print("\n")
        counts["results"] += 1
        if isinstance(result, SkippedAttackResult):
            counts["skipped"] += 1
        if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
            counts["successes"] += 1
        if isinstance(result, FailedAttackResult):
            counts["failures"] += 1
        pbar.set_description(
            "[Succeeded / Failed / Skipped / Total] "
            f"{counts['successes']} / {counts['failures']} / "
            f"{counts['skipped']} / {counts['results']}"
        )

    def _maybe_checkpoint(self, worklist, worklist_candidates):
        """Save a checkpoint whenever the logged-result count is a multiple of ``checkpoint_interval``."""
        if (
            self.attack_args.checkpoint_interval
            and len(self.attack_log_manager.results)
            % self.attack_args.checkpoint_interval
            == 0
        ):
            new_checkpoint = textattack.shared.AttackCheckpoint(
                self.attack_args,
                self.attack_log_manager,
                worklist,
                worklist_candidates,
            )
            new_checkpoint.save()
            self.attack_log_manager.flush()

    def _log_summary(self):
        """Re-enable stdout for the summary if needed, then log and flush it."""
        print()
        # Enable summary stdout.
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()

    def simple_attack(self, text, label):
        """Attack a single input, outside of the dataset worklist and without logging.

        Args:
            text: Raw input text for one example; it is wrapped as ``AttackedText``.
            label: Ground-truth output for ``text``.

        Returns:
            The attack result, or ``None`` when the dataset loop would discard it.
            That is the case for a skipped result when ``attack_n`` is set, or for
            a non-successful result when ``num_successful_examples`` is set.
        """
        if torch.cuda.is_available():
            self.attack.cuda_()

        example = self._to_attacked_text(text)
        result = self.attack.attack(example, label)
        if self._should_discard(result):
            return None
        return result

    def _attack(self):
        """Attack the worklist sequentially in the current process (no parallelism)."""
        if torch.cuda.is_available():
            self.attack.cuda_()

        num_remaining_attacks, worklist, worklist_candidates = self._init_worklist()

        if not self.attack_args.silent:
            print(self.attack, "\n")

        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        counts = self._init_counts()
        print_separator = (
            not self.attack_args.disable_stdout and not self.attack_args.silent
        )

        sample_exhaustion_warned = False
        while worklist:
            idx = worklist.popleft()
            try:
                example, ground_truth_output = self.dataset[idx]
            except IndexError:
                # The dataset cannot serve this index; skip it.
                continue
            example = self._to_attacked_text(example)
            result = self.attack.attack(example, ground_truth_output)

            if self._should_discard(result):
                # Replace the discarded example with a fresh candidate, if any remain.
                if worklist_candidates:
                    worklist.append(worklist_candidates.popleft())
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                self._record_result(
                    result, counts, pbar, print_separator=print_separator
                )

            self._maybe_checkpoint(worklist, worklist_candidates)

        pbar.close()
        self._log_summary()

    def _attack_parallel(self):
        """Attack the worklist with a pool of worker processes spread over all visible GPUs.

        Work items are sent to workers through ``in_queue``. ``(idx, result)``
        pairs come back through ``out_queue`` and are logged here in arrival
        order. If a worker reports an exception, the pool is terminated and the
        method returns without logging a summary.
        """
        pytorch_multiprocessing_workaround()

        num_remaining_attacks, worklist, worklist_candidates = self._init_worklist()

        in_queue = torch.multiprocessing.Queue()
        out_queue = torch.multiprocessing.Queue()
        for i in worklist:
            try:
                example, ground_truth_output = self.dataset[i]
                example = self._to_attacked_text(example)
                in_queue.put((i, example, ground_truth_output))
            except IndexError as err:
                raise IndexError(
                    f"Tried to access element at {i} in dataset of size {len(self.dataset)}."
                ) from err

        # Workers are spread across every visible GPU; this process only
        # coordinates (feeds the queue and logs results).
        num_gpus = torch.cuda.device_count()
        num_workers = self.attack_args.num_workers_per_device * num_gpus
        logger.info(f"Running {num_workers} worker(s) on {num_gpus} GPU(s).")

        # Lock for synchronization
        lock = mp.Lock()

        # We move Attacker (and its components) to CPU b/c we don't want models using wrong GPU in worker processes.
        self.attack.cpu_()
        torch.cuda.empty_cache()

        # Start workers. `attack_from_queue` runs as each worker's initializer and
        # loops until it receives the END sentinel.
        worker_pool = torch.multiprocessing.Pool(
            num_workers,
            attack_from_queue,
            (
                self.attack,
                self.attack_args,
                num_gpus,
                mp.Value("i", 1, lock=False),
                lock,
                in_queue,
                out_queue,
            ),
        )

        # Log results asynchronously and update progress bar.
        counts = self._init_counts()

        logger.info(f"Worklist size: {len(worklist)}")
        logger.info(f"Worklist candidate size: {len(worklist_candidates)}")

        sample_exhaustion_warned = False
        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        while worklist:
            idx, result = out_queue.get(block=True)
            worklist.remove(idx)

            if isinstance(result, tuple) and isinstance(result[0], Exception):
                # A worker failed on this input: report it and tear everything down.
                logger.error(
                    f'Exception encountered for input "{self.dataset[idx][0]}".'
                )
                error_trace = result[1]
                logger.error(error_trace)
                pbar.close()
                in_queue.close()
                in_queue.join_thread()
                out_queue.close()
                out_queue.join_thread()
                worker_pool.terminate()
                worker_pool.join()
                return
            elif self._should_discard(result):
                # Replace the discarded example with a fresh candidate, if any remain.
                if worklist_candidates:
                    next_sample = worklist_candidates.popleft()
                    example, ground_truth_output = self.dataset[next_sample]
                    example = self._to_attacked_text(example)
                    worklist.append(next_sample)
                    in_queue.put((next_sample, example, ground_truth_output))
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                self._record_result(result, counts, pbar)

            self._maybe_checkpoint(worklist, worklist_candidates)

        # Send sentinel values to worker processes
        for _ in range(num_workers):
            in_queue.put(("END", "END", "END"))
        worker_pool.close()
        worker_pool.join()

        pbar.close()
        self._log_summary()

    def attack_dataset(self):
        """Attack the dataset.

        Returns:
            :obj:`list[AttackResult]` - List of :class:`~textattack.attack_results.AttackResult` obtained after attacking the given dataset.

        Raises:
            ValueError: If checkpointing is requested for a dataset that was shuffled internally.
            Exception: If ``parallel`` is set but no GPU is available.
        """
        silent = self.attack_args.silent
        previous_level = logger.level
        if silent:
            logger.setLevel(logging.ERROR)

        try:
            if self.attack_args.query_budget:
                self.attack.goal_function.query_budget = self.attack_args.query_budget

            if not self.attack_log_manager:
                self.attack_log_manager = AttackArgs.create_loggers_from_args(
                    self.attack_args
                )

            textattack.shared.utils.set_seed(self.attack_args.random_seed)
            if self.dataset.shuffled and self.attack_args.checkpoint_interval:
                # Not allowed b/c we cannot recover order of shuffled data
                raise ValueError(
                    "Cannot use `--checkpoint-interval` with dataset that has been internally shuffled."
                )

            # -1 means "attack the whole dataset".
            self.attack_args.num_examples = (
                len(self.dataset)
                if self.attack_args.num_examples == -1
                else self.attack_args.num_examples
            )
            if self.attack_args.parallel:
                if torch.cuda.device_count() == 0:
                    raise Exception(
                        "Found no GPU on your system. To run attacks in parallel, GPU is required."
                    )
                self._attack_parallel()
            else:
                self._attack()
        finally:
            if silent:
                # Undo the temporary silencing even if the attack raised.
                logger.setLevel(previous_level)

        return self.attack_log_manager.results

    def update_attack_args(self, **kwargs):
        """To update any attack args, pass the new argument as keyword argument
        to this function.

        Raises:
            ValueError: If any keyword is not a field of the attack args. In that
                case no field is updated.

        Examples::

            >>> attacker = #some instance of Attacker
            >>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
            >>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
        """
        # Validate every key first so an invalid call leaves the args untouched.
        for name in kwargs:
            if not hasattr(self.attack_args, name):
                raise ValueError(f"`textattack.AttackArgs` does not have field {name}.")
        for name, value in kwargs.items():
            setattr(self.attack_args, name, value)

    @classmethod
    def from_checkpoint(cls, attack, dataset, checkpoint):
        """Resume attacking from a saved checkpoint. Attacker and dataset must
        be recovered by the user again, while attack args are loaded from the
        saved checkpoint.

        Args:
            attack (:class:`~textattack.Attack`):
                Attack object for carrying out the attack.
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to attack.
            checkpoint (:obj:`Union[str, :class:`~textattack.shared.AttackCheckpoint`]`):
                Path of saved checkpoint or the actual saved checkpoint.
        """
        assert isinstance(
            checkpoint, (str, textattack.shared.AttackCheckpoint)
        ), f"`checkpoint` must be of type `str` or `textattack.shared.AttackCheckpoint`, but got type `{type(checkpoint)}`."

        if isinstance(checkpoint, str):
            checkpoint = textattack.shared.AttackCheckpoint.load(checkpoint)
        attacker = cls(attack, dataset, checkpoint.attack_args)
        attacker.attack_log_manager = checkpoint.attack_log_manager
        attacker._checkpoint = checkpoint
        return attacker

    @staticmethod
    def attack_interactive(attack):
        """Read sentences from stdin and attack each one until the user enters "q".

        Each sentence's ground-truth output is taken from the attack's goal
        function, via ``goal_function.get_output``.
        """
        print(attack, "\n")

        print("Running in interactive mode")
        print("----------------------------")

        while True:
            print('Enter a sentence to attack or "q" to quit:')
            text = input()

            if text == "q":
                break

            if not text:
                continue

            print("Attacking...")

            example = textattack.shared.attacked_text.AttackedText(text)
            output = attack.goal_function.get_output(example)
            result = attack.attack(example, output)
            print(result.__str__(color_method="ansi") + "\n")
#
# Helper Methods for multiprocess attacks
#
def pytorch_multiprocessing_workaround():
    """Configure ``torch.multiprocessing`` before the parallel worker pool is created.

    Forces the ``"spawn"`` start method, then selects the ``"file_system"``
    tensor sharing strategy. If either call raises ``RuntimeError``, the error
    is swallowed and any remaining step is skipped.
    """
    # NOTE(review): "spawn" is presumably needed because workers initialise CUDA,
    # which PyTorch does not support in forked children; confirm.
    torch_mp = torch.multiprocessing
    try:
        torch_mp.set_start_method("spawn", force=True)
        torch_mp.set_sharing_strategy("file_system")
    except RuntimeError:
        # Best effort: keep whatever configuration is already in place.
        return
def set_env_variables(gpu_id: int) -> None:
    """Prepare the current worker process to use the single GPU ``gpu_id``.

    Steps, in order:
      * set ``TF_CPP_MIN_LOG_LEVEL=3``, but only if the user has not set it;
      * switch torch tensor sharing to the ``file_system`` strategy;
      * set ``CUDA_VISIBLE_DEVICES`` to ``gpu_id``;
      * select ``gpu_id`` as the current torch CUDA device;
      * if TensorFlow is installed and reports GPUs, make only ``gpus[gpu_id]``
        visible to TensorFlow and enable memory growth on it.

    The order matters: the environment variables are written before the torch
    and TensorFlow device calls that follow them.

    Args:
        gpu_id: Index of the GPU this process should use.
    """
    # Disable tensorflow logs, except in the case of an error.
    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Set sharing strategy to file_system to avoid file descriptor leaks
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Only use one GPU, if we have one.
    # For Tensorflow
    # TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
    # NOTE(review): if a library initialises CUDA *after* this line, it presumably
    # sees only this one device, renumbered as index 0. Confirm that
    # `torch.cuda.set_device(gpu_id)` and `gpus[gpu_id]` below still refer to a
    # valid device when gpu_id > 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # For PyTorch
    torch.cuda.set_device(gpu_id)

    # Fix TensorFlow GPU memory growth
    try:
        # Imported lazily so TensorFlow stays an optional dependency.
        import tensorflow as tf

        gpus = tf.config.experimental.list_physical_devices("GPU")
        if gpus:
            try:
                # Currently, memory growth needs to be the same across GPUs
                # NOTE(review): an IndexError from `gpus[gpu_id]` (fewer devices
                # listed than gpu_id + 1) is not caught by either handler here.
                gpu = gpus[gpu_id]
                tf.config.experimental.set_visible_devices(gpu, "GPU")
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # Reported and ignored. Presumably raised when TensorFlow's GPUs
                # were already initialised; confirm.
                print(e)
    except ModuleNotFoundError:
        # TensorFlow is not installed; nothing to configure.
        pass
def attack_from_queue(
    attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue
):
    """Worker loop for the parallel attack pool; used as the pool initializer.

    The worker first pins itself to a GPU chosen from its process identity and
    seeds the random generators. It then repeatedly takes
    ``(idx, example, ground_truth_output)`` items from ``in_queue`` and attacks
    them.

    On success it puts ``(idx, result)`` on ``out_queue``. If attacking an
    example raises, it puts ``(idx, (exception, formatted_traceback))`` instead
    and continues. The loop ends when the ``("END", "END", "END")`` sentinel is
    received.

    Args:
        attack (Attack): Attack to run; moved to this worker's GPU with ``cuda_()``.
        attack_args (AttackArgs): Supplies ``random_seed`` and ``silent``.
        num_gpus (int): Number of GPUs the workers are spread across.
        first_to_start: Shared integer flag. It is 1 until some worker claims
            the right to print the attack.
        lock: Lock guarding updates to ``first_to_start``.
        in_queue: Queue of work items and END sentinels.
        out_queue: Queue receiving ``(idx, result)`` pairs.
    """
    assert isinstance(
        attack, Attack
    ), f"`attack` must be of type `Attack`, but got type `{type(attack)}`."

    worker_rank = torch.multiprocessing.current_process()._identity[0]
    gpu_id = (worker_rank - 1) % num_gpus
    set_env_variables(gpu_id)
    textattack.shared.utils.set_seed(attack_args.random_seed)
    if worker_rank > 1:
        # Only the first worker keeps logging enabled.
        logging.disable()
    attack.cuda_()

    # Simple non-synchronized check to see if it's the first process to reach this point.
    # This let us avoid waiting for lock.
    if bool(first_to_start.value):
        # If it's first process to reach this step, we first try to acquire the lock to update the value.
        with lock:
            # Because another process could have changed `first_to_start=False` while we wait, we check again.
            if bool(first_to_start.value):
                first_to_start.value = 0
                if not attack_args.silent:
                    print(attack, "\n")

    while True:
        try:
            i, example, ground_truth_output = in_queue.get(timeout=5)
        except queue.Empty:
            # Nothing queued yet; keep polling until the sentinel arrives.
            continue

        if i == "END" and example == "END" and ground_truth_output == "END":
            # End process when sentinel value is received
            break

        # Only the attack itself is guarded, so a reported error is always
        # attributed to the index of the example that actually failed.
        try:
            result = attack.attack(example, ground_truth_output)
        except Exception as e:
            out_queue.put((i, (e, traceback.format_exc())))
        else:
            out_queue.put((i, result))