# Rapid-Textual-Adversarial-Defense/textattack/constraints/pre_transformation/input_column_modification.py
"""
Input Column Modification
--------------------------
"""
from textattack.constraints import PreTransformationConstraint | |
class InputColumnModification(PreTransformationConstraint):
    """A constraint disallowing the modification of words within a specific
    input column.

    For example, can prevent modification of 'premise' during
    entailment.

    Args:
        matching_column_labels: Column labels, in order, that an input's
            ``column_labels`` must equal for the constraint to apply. Any
            iterable of labels is accepted (list, tuple, ...); a bare string
            is treated as a single label.
        columns_to_ignore: Column labels whose words may not be modified when
            the constraint applies. Any iterable of labels is accepted; a
            bare string is treated as a single label, and ``None`` means no
            column is ignored.
    """

    def __init__(self, matching_column_labels, columns_to_ignore):
        # Stored exactly as given so that ``extra_repr_keys`` reports what the
        # caller passed in; normalization happens at use time.
        self.matching_column_labels = matching_column_labels
        self.columns_to_ignore = columns_to_ignore

    @staticmethod
    def _as_label_tuple(labels):
        """Normalize ``labels`` into a tuple of column labels.

        ``None`` becomes an empty tuple and a bare ``str`` becomes a
        one-element tuple (rather than being split into characters or
        searched by substring); any other iterable is converted with
        ``tuple()``.
        """
        if labels is None:
            return ()
        if isinstance(labels, str):
            return (labels,)
        return tuple(labels)

    def _get_modifiable_indices(self, current_text):
        """Returns the word indices in current_text which are able to be
        modified.

        If ``current_text.column_labels`` doesn't match
        ``self.matching_column_labels``, do nothing, and allow all words
        to be modified.

        If it does match, only allow words to be modified if they are not
        in columns from ``columns_to_ignore``.

        Returns:
            set[int]: Indices into ``current_text.words`` that may be modified.
        """
        column_labels = current_text.column_labels
        expected_labels = self._as_label_tuple(self.matching_column_labels)
        # Compare as tuples: a list never equals a tuple in Python, so a
        # caller passing e.g. ``("premise", "hypothesis")`` would otherwise
        # silently disable the constraint.
        if tuple(column_labels) != expected_labels:
            return set(range(len(current_text.words)))

        ignored = self._as_label_tuple(self.columns_to_ignore)
        indices_to_modify = set()
        start = 0
        # Word indices run across all columns, so keep a running offset while
        # walking each column's words in order. This assumes
        # ``current_text.words`` is ``words_per_input`` concatenated in column
        # order (same assumption as before).
        for column, words in zip(column_labels, current_text.words_per_input):
            end = start + len(words)
            if column not in ignored:
                indices_to_modify.update(range(start, end))
            start = end
        return indices_to_modify

    def extra_repr_keys(self):
        """Return the attribute names to include in this constraint's repr.

        Presumably consumed by the base class's repr machinery (not visible
        here).
        """
        return ["matching_column_labels", "columns_to_ignore"]