File size: 2,744 Bytes
b16fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import textattack
import transformers
from FlowCorrector import Flow_Corrector
import torch
import torch.nn.functional as F

def count_matching_classes(original, corrected):
    """Count the positions at which two equal-length sequences agree.

    Args:
        original: Reference sequence (e.g. the classes predicted for the
            clean texts).
        corrected: Sequence compared element-by-element against ``original``.

    Returns:
        int: Number of indices ``i`` for which ``original[i] == corrected[i]``.

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    # Each index where the pair compares equal contributes exactly 1.
    return sum(
        1 for idx in range(len(corrected)) if original[idx] == corrected[idx]
    )

if __name__ == "__main__":
    # Benchmark: run Flow_Corrector on adversarial AG News texts and count how
    # many of its predicted classes match the victim model's predictions on
    # the corresponding original (clean) texts.

    # Load model, tokenizer, and model_wrapper
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct our four components for `Attack`
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.transformations import WordSwapEmbedding
    from textattack.search_methods import GreedyWordSwapWIR

    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    # Only move to the GPU when one exists, so the benchmark also runs on
    # CPU-only machines.
    if torch.cuda.is_available():
        attack.cuda_()

    # Initialize the corrector
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # All these texts are adversarial ones, one per line.
    with open("perturbed_texts_ag_news.txt", "r", encoding="utf-8") as f:
        detected_texts = [line.strip() for line in f]

    # These are the original texts, in the same order as the adversarial ones.
    with open("original_texts_ag_news.txt", "r", encoding="utf-8") as f:
        original_texts = [line.strip() for line in f]

    # Fail fast, before the expensive correction step, if the two files are
    # not aligned line-for-line.
    if len(detected_texts) != len(original_texts):
        raise ValueError(
            "perturbed/original text count mismatch: "
            f"{len(detected_texts)} vs {len(original_texts)}"
        )

    victim_model = attack.goal_function.model

    # Class labels: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech
    #
    # Original labels, used as the reference for benchmarking. The wrapper is
    # given a one-element list of texts; the argmax of its output is the
    # predicted class (softmax is monotonic, so applying it first would not
    # change the argmax).
    original_classes = [
        torch.argmax(victim_model([original_text])).item()
        for original_text in original_texts
    ]

    # Correct the *adversarial* texts; correcting the originals would make the
    # comparison below meaningless.
    # NOTE(review): assumes Flow_Corrector.correct returns one predicted class
    # index per input text, in input order - confirm.
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")