from copy import deepcopy

import torch
from torch import nn
from torch.cuda import amp
from tqdm import tqdm

from .base import Attacker, Empty


class PGD(Attacker):
    """Iterative projected-gradient (PGD) attack built on ``Attacker``.

    The attack can be driven in two ways:

    * ``attack(images, labels)`` runs all ``self.iters`` steps in one call.
    * ``set_data(...)`` followed by iterating over the instance yields the
      intermediate adversarial batch after every step.

    ``self.model``, ``self.img_transform``, ``self.eps``, ``self.alpha``,
    ``self.iters`` and ``self.forward`` are not assigned in this class; they
    are presumably provided by ``Attacker.__init__`` / ``Attacker.set_para``
    (TODO confirm against ``.base``).
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x), use_amp=False):
        """Create the attacker.

        Args:
            model: Network under attack. Its output must expose ``.logits``
                (see ``step``).
            img_transform: Pair of callables. In ``step`` element ``1`` is
                applied before clamping to ``[0, 1]`` and element ``0`` is
                applied to the clamped result, i.e. ``[1]`` maps into the
                space where valid pixels lie in ``[0, 1]`` and ``[0]`` maps
                back. Defaults to two identity functions.
            use_amp: Enable CUDA autocast and gradient scaling in ``step``.
        """
        super().__init__(model, img_transform)
        self.use_amp = use_amp
        # Optional observer invoked after every step of ``attack``.
        self.call_back = None
        # Stored by ``set_img_loader``; not read anywhere in this class.
        self.img_loader = None
        # Optional post-processing hook for ``attack``; it has no setter here
        # and must be assigned directly on the instance.
        self.img_hook = None
        self.scaler = amp.GradScaler(enabled=use_amp)

    def set_para(self, eps=8, alpha=lambda: 8, iters=20, **kwargs):
        """Forward the attack hyper-parameters to ``Attacker.set_para``.

        Args:
            eps: Perturbation bound used to clamp the perturbation in ``step``.
            alpha: Zero-argument callable returning the step size; it is
                called once per step.
            iters: Number of attack steps.
            **kwargs: Passed through unchanged to the base class.
        """
        super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)

    def set_call_back(self, call_back):
        """Register ``call_back(ori_images, adv_images, labels)`` for ``attack``."""
        self.call_back = call_back

    def set_img_loader(self, img_loader):
        """Store ``img_loader`` on the instance."""
        self.img_loader = img_loader

    def step(self, images, labels, loss):
        """Perform one signed-gradient ascent step and project the result.

        Args:
            images: Current adversarial batch.
            labels: Targets handed to ``loss``.
            loss: Callable ``loss(logits, labels)`` returning the scalar to
                ascend.

        Returns:
            The updated adversarial batch, detached from the autograd graph.
        """
        # Differentiate a fresh leaf instead of flipping ``requires_grad`` on
        # the caller's tensor: that would mutate the caller's input, fail on
        # non-leaf tensors, and accumulate onto a stale ``.grad`` whenever the
        # same tensor is attacked more than once.
        images = images.detach().requires_grad_(True)

        with amp.autocast(enabled=self.use_amp):
            outputs = self.model(images).logits
            self.model.zero_grad()
            # NOTE(review): the original carried a commented-out variant that
            # added zero-weighted terms of outputs[0..2] to the loss, tagged
            # "support DDP" -- presumably to keep every output in the graph.
            cost = loss(outputs, labels)

        # Loss scaling multiplies the gradient by a positive factor, so the
        # sign used below is unaffected by it.
        self.scaler.scale(cost).backward()

        adv_images = (images + self.alpha() * images.grad.sign()).detach()
        # Project the perturbation back into the eps-ball around the originals.
        eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
        # Clamp to the valid pixel range in img_transform[1] space, then map
        # back with img_transform[0].
        in_range = torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach()
        return self.img_transform[0](in_range)

    def set_data(self, images, labels):
        """Load a batch for the iterator interface.

        A deep copy of ``images`` is kept as the clean reference; ``images``
        itself becomes the starting point of the attack.
        """
        self.ori_images = deepcopy(images)
        self.images = images
        self.labels = labels

    def __iter__(self):
        """Reset the step counter and return the attacker as its own iterator."""
        self.atk_step = 0
        return self

    def __next__(self):
        """Advance the attack by one step.

        Returns:
            Tuple ``(ori_images, adv_images, labels)`` with ``adv_images``
            detached.

        Raises:
            StopIteration: Once ``self.iters`` steps have been produced.
        """
        self.atk_step += 1
        if self.atk_step > self.iters:
            raise StopIteration
        self.images = self._run_step(self.images, self.labels)
        return self.ori_images, self.images.detach(), self.labels

    def attack(self, images, labels, step_func=None):
        """Run the full attack of ``self.iters`` steps.

        Args:
            images: Clean batch; a deep copy is stored as ``self.ori_images``.
            labels: Targets for the loss.
            step_func: Optional callable invoked with the 1-based step index
                after each step.

        Returns:
            The adversarial batch after the last step (``images`` unchanged
            if ``self.iters`` is 0).
        """
        self.ori_images = deepcopy(images)
        for i in tqdm(range(self.iters)):
            images = self._run_step(images, labels)
            # Observers run outside the DDP no_sync guard so that anything
            # they do with the model keeps its normal synchronisation.
            if self.call_back:
                self.call_back(self.ori_images, images.detach(), labels)
            if self.img_hook is not None:
                images = self.img_hook(self.ori_images, images.detach())
            if step_func:
                step_func(i + 1)
        return images

    def _run_step(self, images, labels):
        """Run one attack step in eval mode, shared by ``attack`` and ``__next__``.

        For a ``DistributedDataParallel`` model the step runs under
        ``no_sync()``: the parameter gradients produced while attacking are
        discarded by ``zero_grad()`` right away, so all-reducing them would be
        wasted communication.
        """
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            guard = self.model.no_sync()
        else:
            # ``Empty`` comes from ``.base``; presumably a no-op context manager.
            guard = Empty()
        with guard:
            self.model.eval()
            # ``self.forward`` receives the attacker explicitly, so it is
            # presumably a plain function stored on the instance rather than
            # a bound method -- TODO confirm in ``.base``.
            images = self.forward(self, images, labels)
            self.model.zero_grad()
            # NOTE: the model is always left in train mode, whatever mode it
            # was in before the step; kept because callers may rely on it.
            self.model.train()
        return images