# Captured from a Hugging Face Space file view
# (status: "Runtime error"; file size 2,744 bytes; commit b16fdae).
import textattack
import transformers
from FlowCorrector import Flow_Corrector
import torch
import torch.nn.functional as F
def count_matching_classes(original, corrected):
    """Count the positions at which two equal-length label sequences agree.

    Args:
        original: Reference labels.
        corrected: Labels to compare against ``original``, in the same order.

    Returns:
        The number of indices ``i`` for which ``original[i] == corrected[i]``.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    # zip pairs elements positionally; the length check above guarantees that
    # no trailing element is silently dropped.
    return sum(
        1 for expected, actual in zip(original, corrected) if expected == actual
    )
def main():
    """Benchmark Flow_Corrector on AG News adversarial texts.

    Steps:
      1. Load the ``textattack/bert-base-uncased-ag-news`` classifier and wrap it.
      2. Build a word-swap ``Attack`` and hand it to ``Flow_Corrector``.
      3. Read the adversarial texts and their clean originals (same order).
      4. Label the clean texts with the victim model, run the corrector on the
         adversarial texts, and print how many resulting labels match.
    """
    # Load model, tokenizer, and model_wrapper.
    model_name = "textattack/bert-base-uncased-ag-news"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct the four components of an `Attack`.
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.search_methods import GreedyWordSwapWIR
    from textattack.transformations import WordSwapEmbedding

    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack.
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    attack.cuda_()

    # Initialise the corrector around the attack.
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # Adversarial (perturbed) texts, one per line.
    with open("perturbed_texts_ag_news.txt", "r", encoding="utf-8") as f:
        detected_texts = [line.strip() for line in f]
    # Clean original texts, in the same order as the adversarial ones.
    with open("original_texts_ag_news.txt", "r", encoding="utf-8") as f:
        original_texts = [line.strip() for line in f]

    victim_model = attack.goal_function.model
    # Reference labels: the victim model's prediction on each clean text.
    # AG News label ids: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech.
    # textattack model wrappers are called with a list of strings, so each text
    # is wrapped in a one-element batch; softmax over dim=1 (the class axis)
    # followed by argmax yields that text's predicted class id.
    original_classes = [
        torch.argmax(F.softmax(victim_model([text]), dim=1)).item()
        for text in original_texts
    ]

    # Correct the adversarial texts (not the clean ones) so the comparison
    # counts how many corrected texts get the same label as their original.
    # NOTE(review): assumes correct() returns one class id per input text, in
    # input order; confirm against FlowCorrector.
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")


if __name__ == "__main__":
    main()