from io import BytesIO
from typing import Tuple

import panel as pn
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
from transformers.image_transforms import to_pil_image
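# Enable the MathJax extension (for the LaTeX slider label), Bootstrap design, and stretch sizing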
pn.extension("mathjax", design="bootstrap", sizing_mode="stretch_width")
@pn.cache
def load_processor_model(
processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
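    """Load and cache (via pn.cache) the image processor and ResNet classifier."""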
processor = AutoImageProcessor.from_pretrained(processor_name)
model = ResNetForImageClassification.from_pretrained(model_name)
return processor, model
def denormalize(image, mean, std):
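    """Map a normalized image tensor back to pixel space: image * std + mean."""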
mean = torch.tensor(mean).view(1, -1, 1, 1) # Reshape for broadcasting
std = torch.tensor(std).view(1, -1, 1, 1)
return image * std + mean
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
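    """Fast Gradient Sign Method (Goodfellow et al., 2015).

    Nudges every pixel by epsilon in the direction that increases the loss:
    x_adv = clamp(x + epsilon * sign(grad_x loss), 0, 1).
    """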
# Collect the element-wise sign of the data gradient
sign_data_grad = data_grad.sign()
# Create the perturbed image by adjusting each pixel of the input image
perturbed_image = image + epsilon * sign_data_grad
# Adding clipping to maintain [0,1] range
perturbed_image = torch.clamp(perturbed_image, 0, 1)
# Return the perturbed image
return perturbed_image.detach()
def run_forward_backward(image: Image.Image, epsilon: float):
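    """Classify `image` with ResNet-18, craft an FGSM adversarial example from
    the input gradient, and classify the perturbed image.

    Returns (clean_logits, adv_logits, clean_image, adv_image), where the two
    images are tensors in [0, 1] pixel space.
    """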
processor, model = load_processor_model(
"microsoft/resnet-18", "microsoft/resnet-18"
)
    # Preprocess: resize straight to the crop size (crop_pct=1) and tensorize
    processor.crop_pct = 1
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    # Track gradients w.r.t. the input pixels for the FGSM step
    input_tensor.requires_grad_(True)
    # Forward pass on the clean input
    output = model(input_tensor).logits
    # Use the model's own top prediction as the attack target
    top_pred = output.argmax(dim=1)
    # Cross-entropy loss against that prediction; backprop to get the input gradient
    loss = F.cross_entropy(output, top_pred)
model.zero_grad()
loss.backward()
    # Denormalize the processed input back to [0, 1] pixel space
    input_tensor_denorm = denormalize(
        input_tensor.detach(), processor.image_mean, processor.image_std
    )
    # Add small random sign noise so each "Regenerate" yields a slightly different sample
    random_noise = torch.sign(torch.randn_like(input_tensor)) * 0.02
    input_tensor_denorm_noised = torch.clamp(input_tensor_denorm + random_noise, 0, 1)
# FGSM attack
adv_input_tensor_denorm = fgsm_attack(
image=input_tensor_denorm_noised,
epsilon=epsilon,
data_grad=input_tensor.grad.data,
)
    # Re-normalize the adversarial image to the model's expected input distribution
    mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
    std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
    adv_input_tensor = (adv_input_tensor_denorm - mean) / std
    # Forward pass on the adversarial input
    adv_output = model(adv_input_tensor).logits
return (
output,
adv_output,
input_tensor_denorm.squeeze(),
adv_input_tensor_denorm.squeeze(),
)
async def process_inputs(button_event, image_data: bytes, epsilon: float):
"""
High level function that takes in the user inputs and returns the
classification results as panel objects.
"""
try:
main.disabled = True
if not isinstance(image_data, bytes):
yield "##### 👋 Upload an image to proceed"
return
yield "##### âš™ Fetching image and running model..."
try:
# Open the image using PIL
pil_img = Image.open(BytesIO(image_data))
# Run forward + FGSM
clean_logits, adv_logits, input_tensor, adv_input_tensor = (
run_forward_backward(image=pil_img, epsilon=epsilon)
)
except Exception as e:
yield f"##### Something went wrong, please try a different image! \n {e}"
return
img = pn.pane.Image(
to_pil_image(input_tensor, do_rescale=True),
height=300,
align="center",
)
# Convert image for visualizing
adv_img_pil = to_pil_image(adv_input_tensor, do_rescale=True)
adv_img = pn.pane.Image(
adv_img_pil,
height=300,
align="center",
)
        # Serialize the adversarial image so it can be offered for download
        adv_img_bytes = BytesIO()
        adv_img_pil.save(adv_img_bytes, format="PNG")
        adv_img_bytes.seek(0)
        # download = pn.widgets.FileDownload(
        #     adv_img_bytes,
        #     embed=True,
        #     filename="adv_img.png",
        #     button_type="primary",
        #     button_style="outline",
        #     width_policy="min",
        # )
# Build the results column
k_val = 5
results = pn.Column(
pn.Row("###### Uploaded", "###### Adversarial"),
pn.Row(img, adv_img),
# pn.Row(pn.Spacer(), download),
f" ###### Top {k_val} class predictions",
)
# Get likelihoods
likelihoods = [
F.softmax(clean_logits, dim=1).squeeze(),
F.softmax(adv_logits, dim=1).squeeze(),
]
label_bars_rows = pn.Row()
        for likelihood_tensor in likelihoods:
            # Top-k probabilities and class indices for this prediction set
            vals_topk, idx_topk = torch.topk(likelihood_tensor, k=k_val)
            label_bars = pn.Column()
            for idx, val in zip(idx_topk, vals_topk):
                prob = val.item()
                row_label = pn.widgets.StaticText(
                    name=classes[idx], value=f"{prob:.2%}", align="center"
                )
                row_bar = pn.indicators.Progress(
                    value=int(prob * 100),
                    sizing_mode="stretch_width",
                    # Green bar for confident predictions, amber otherwise
                    bar_color="success" if prob > 0.7 else "warning",
                    margin=(0, 10),
                    design=pn.theme.Material,
                )
                label_bars.append(pn.Column(row_label, row_bar))
label_bars_rows.append(label_bars)
results.append(label_bars_rows)
yield results
except Exception as e:
yield f"##### Something went wrong! \n {e}"
return
finally:
main.disabled = False
####################################################################################################################################
# Load the ImageNet-1K class names, one per line
with open("classes.txt", "r") as file:
    classes = file.read().splitlines()
# Create widgets
############################################
# File upload widget
file_input = pn.widgets.FileInput(name="Upload an image", accept=".png,.jpg,.jpeg")
# Epsilon
epsilon_slider = pn.widgets.FloatSlider(
name=r"$$\epsilon$$ parameter for FGSM",
start=0,
end=0.1,
step=0.005,
value=0.005,
format="1[.]000",
align="center",
max_width=500,
width_policy="max",
)
# Regenerate button
regenerate = pn.widgets.Button(
name="Regenerate",
button_type="primary",
width_policy="min",
max_width=105,
)
############################################
# Organize widgets in a column
input_widgets = pn.Column(
"""
    ###### Classify an image (PNG/JPEG) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
    Wondering where the class names come from? Find the list of ImageNet-1K classes [here](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/).
    *Please be patient; the application runs on a low-resource device.*
""",
file_input,
pn.Row(epsilon_slider, pn.Spacer(width_policy="min", max_width=25), regenerate),
)
# Add interactivity
interactive_result = pn.panel(
pn.bind(
process_inputs,
regenerate,
file_input.param.value,
epsilon_slider.param.value,
),
height=600,
)
footer = pn.pane.Markdown(
"""
If the application is too slow for you, head over to the README to get this running locally.
"""
)
# Create dashboard
main = pn.WidgetBox(
input_widgets,
interactive_result,
footer,
)
title = "Adversarial Sample Generation"
pn.template.BootstrapTemplate(
title=title,
main=main,
main_max_width="min(75%, 698px)",
header_background="#101820",
).servable(title=title)