Spaces:
Sleeping
Sleeping
import io | |
import random | |
from io import BytesIO | |
from typing import List, Tuple | |
import aiohttp | |
import panel as pn | |
import torch | |
from bokeh.themes import Theme | |
# import torchvision.transforms.functional as TVF | |
import torch.nn.functional as F | |
from PIL import Image | |
from transformers import AutoImageProcessor, ResNetForImageClassification | |
from transformers.image_transforms import to_pil_image | |
# Torch device identifier. Nothing in the visible part of this file reads it.
DEVICE = "cpu"

# Global Panel configuration: load the MathJax extension (the epsilon slider
# label is written with $$...$$ markup) and set the default design and sizing.
pn.extension(
    "mathjax",
    design="bootstrap",
    sizing_mode="stretch_width",
)
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[AutoImageProcessor, ResNetForImageClassification]:
    """Load the image processor and the ResNet classifier.

    The result is memoised with ``pn.cache`` so the checkpoint is read once
    per argument combination instead of on every request; callers invoke this
    function each time the inputs change.

    Args:
        processor_name: Hub identifier or local path of the image processor.
        model_name: Hub identifier or local path of the classification model.

    Returns:
        A ``(processor, model)`` tuple. The same objects are handed back on
        subsequent calls with the same arguments.
    """
    processor = AutoImageProcessor.from_pretrained(processor_name)
    model = ResNetForImageClassification.from_pretrained(model_name)
    # The model is only used for inference and input-gradient computation, so
    # make evaluation mode explicit rather than relying on the loader default.
    model.eval()
    return processor, model
def denormalize(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Args:
        image: Tensor whose dimension 1 is the channel dimension; the
            statistics are reshaped to ``(1, C, 1, 1)`` and broadcast
            against it.
        mean: Per-channel means (sequence of floats or a tensor).
        std: Per-channel standard deviations (sequence of floats or a tensor).

    Returns:
        ``image * std + mean`` as a new tensor.
    """
    # as_tensor avoids the copy-construct warning when the statistics are
    # already tensors, and placing them on the image's device keeps the
    # arithmetic valid when the image does not live on the CPU.
    mean = torch.as_tensor(mean, device=image.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, device=image.device).view(1, -1, 1, 1)
    return image * std + mean
# FGSM attack code | |
def fgsm_attack(image, epsilon, data_grad):
    """Take one Fast Gradient Sign Method step from ``image``.

    Args:
        image: Tensor to perturb; values are expected in ``[0, 1]``.
        epsilon: Step size applied to every element.
        data_grad: Gradient tensor; only its element-wise sign is used.

    Returns:
        The perturbed tensor, clipped to ``[0, 1]`` and detached from the
        autograd graph.
    """
    # Move every element by epsilon in the direction of the gradient's sign.
    step = epsilon * data_grad.sign()
    shifted = image + step
    # Clip back into the valid pixel range before handing the result out.
    clipped = shifted.clamp(min=0, max=1)
    return clipped.detach()
def run_forward_backward(image: Image.Image, epsilon, *, noise_scale: float = 0.02):
    """Classify ``image`` and build an FGSM adversarial version of it.

    The loss is the cross-entropy against the model's own top prediction, so
    the attack pushes the image away from whatever class the model currently
    assigns to it.

    Args:
        image: PIL image to classify.
        epsilon: FGSM step size, applied in de-normalised ``[0, 1]`` pixel
            space.
        noise_scale: Magnitude of the random sign noise added to the
            de-normalised image before the FGSM step. Defaults to the value
            that was previously hard-coded.

    Returns:
        A 4-tuple ``(clean_logits, adv_logits, clean_image, adv_image)``.
        Both logits tensors are detached and keep the batch dimension; both
        images are de-normalised tensors with the batch dimension removed.
    """
    processor, model = load_processor_model(
        "microsoft/resnet-18", "microsoft/resnet-18"
    )

    # The normalisation statistics below are per RGB channel, so anything
    # that is not plain RGB (for example an RGBA or palette PNG) is converted
    # before it reaches the processor.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Grab input
    # NOTE(review): crop_pct is overridden on the processor, presumably so the
    # whole image is used instead of a centre crop - confirm against the
    # processor's documentation.
    processor.crop_pct = 1
    input_tensor = processor(image, return_tensors="pt")["pixel_values"]
    input_tensor.requires_grad_(True)

    # Run inference
    output = model(input_tensor).logits

    # Top target: the class the model itself predicts for the clean image.
    top_pred = output.argmax(dim=1)

    # Cross-entropy against that prediction, back-propagated to the input.
    loss = F.cross_entropy(output, top_pred)
    model.zero_grad()
    loss.backward()

    # Denormalize input
    mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
    std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
    input_tensor_denorm = input_tensor.detach() * std + mean

    # Add random +/- noise_scale noise to every element of the input.
    random_noise = torch.sign(torch.randn_like(input_tensor)) * noise_scale
    input_tensor_denorm_noised = torch.clamp(input_tensor_denorm + random_noise, 0, 1)

    # FGSM attack. The gradient was taken with respect to the normalised
    # input; only its sign is used, and dividing by a positive std would not
    # change that sign.
    adv_input_tensor_denorm = fgsm_attack(
        image=input_tensor_denorm_noised,
        epsilon=epsilon,
        data_grad=input_tensor.grad,
    )

    # Normalize adversarial input tensor back to the input range
    adv_input_tensor = (adv_input_tensor_denorm - mean) / std

    # Inference on adversarial image; no gradients are needed for this pass.
    with torch.no_grad():
        adv_output = model(adv_input_tensor).logits

    return (
        output.detach(),
        adv_output,
        input_tensor_denorm.squeeze(0),
        adv_input_tensor_denorm.squeeze(0),
    )
def _prediction_bars(likelihood_tensor, k):
    """Build a column with a label and a progress bar for the top-k classes."""
    # Never ask for more entries than there are classes.
    k = min(k, likelihood_tensor.numel())
    top_vals, top_idxs = torch.topk(likelihood_tensor, k=k)
    label_bars = pn.Column()
    # tolist() yields plain ints/floats, so the class list is indexed with a
    # Python int rather than a tensor.
    for idx, prob in zip(top_idxs.tolist(), top_vals.tolist()):
        row_label = pn.widgets.StaticText(
            name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
        )
        row_bar = pn.indicators.Progress(
            value=int(prob * 100),
            sizing_mode="stretch_width",
            # Dynamic color based on value
            bar_color="success" if prob > 0.7 else "warning",
            margin=(0, 10),
            design=pn.theme.Material,
        )
        label_bars.append(pn.Column(row_label, row_bar))
    return label_bars


async def process_inputs(button_event, image_data: bytes, epsilon: float):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    This is an async generator: it first yields Markdown status strings and
    finally a ``pn.Column`` holding the uploaded and adversarial images above
    the top-5 predictions for each. The dashboard (``main``) is disabled while
    the function runs and re-enabled in all cases.

    Args:
        button_event: Value of the "Regenerate" button. It is not read; it is
            a parameter so that a click re-triggers this bound function.
        image_data: Raw bytes of the uploaded file; any other type is treated
            as "nothing uploaded yet".
        epsilon: FGSM step size forwarded to ``run_forward_backward``.
    """
    # NOTE(review): the leading glyphs in the two status strings below look
    # like mis-decoded emoji; they are kept byte-for-byte - confirm the
    # intended characters against the deployed app.
    try:
        main.disabled = True

        if not isinstance(image_data, bytes):
            yield "##### π Upload an image to proceed"
            return

        yield "##### β Fetching image and running model..."
        try:
            # Open the image using PIL
            pil_img = Image.open(BytesIO(image_data))
            # Run forward + FGSM
            clean_logits, adv_logits, input_tensor, adv_input_tensor = (
                run_forward_backward(image=pil_img, epsilon=epsilon)
            )
        except Exception as e:
            yield f"##### Something went wrong, please try a different image! \n {e}"
            return

        # Convert both tensors to PIL images for display.
        img = pn.pane.Image(
            to_pil_image(input_tensor, do_rescale=True),
            height=300,
            align="center",
        )
        adv_img = pn.pane.Image(
            to_pil_image(adv_input_tensor, do_rescale=True),
            height=300,
            align="center",
        )

        # Build the results column
        k_val = 5
        results = pn.Column(
            pn.Row("###### Uploaded", "###### Adversarial"),
            pn.Row(img, adv_img),
            f" ###### Top {k_val} class predictions",
        )

        # One column of prediction bars per image, clean first; the order
        # matches the image row above.
        label_bars_rows = pn.Row()
        for logits in (clean_logits, adv_logits):
            likelihood = F.softmax(logits.detach(), dim=1).squeeze(0)
            label_bars_rows.append(_prediction_bars(likelihood, k_val))
        results.append(label_bars_rows)

        yield results
    except Exception as e:
        yield f"##### Something went wrong! \n {e}"
        return
    finally:
        main.disabled = False
#################################################################################################################################### | |
# Get classes
# classes.txt holds one class name per line; the list is indexed with the
# class indices taken from the model's top-k predictions.
with open("classes.txt", "r", encoding="utf-8") as file:
    # splitlines() does not leave a trailing empty entry when the file ends
    # with a newline, unlike split("\n").
    classes = file.read().splitlines()
# Create widgets
############################################
# File upload widget. The accepted extensions cover both common JPEG
# spellings, matching the "png/jpeg" wording of the intro text.
file_input = pn.widgets.FileInput(
    name="Upload a PNG or JPEG image",
    accept=".png,.jpg,.jpeg",
)

# Epsilon: FGSM step size, selectable from 0 to 0.1 in steps of 0.005.
epsilon_slider = pn.widgets.FloatSlider(
    name=r"$$\epsilon$$ parameter for FGSM",
    start=0,
    end=0.1,
    step=0.005,
    value=0.000,
    format="1[.]000",
    align="center",
    max_width=500,
    width_policy="max",
)

# Regenerate button: clicking it re-runs the bound processing function with
# the current image and epsilon.
regenerate = pn.widgets.Button(
    name="Regenerate",
    button_type="primary",
    width_policy="min",
    max_width=105,
)
############################################
# Input area: intro text, the file picker, then the epsilon slider next to the
# regenerate button.
controls_row = pn.Row(
    epsilon_slider,
    pn.Spacer(width_policy="min", max_width=25),
    regenerate,
)
input_widgets = pn.Column(
    """
    ###### Classify an image (png/jpeg) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
    Wondering where the class names come from? Find the list of ImageNet-1K classes [here.](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/)
    *Please be patient with the application, it is running on a low-resource device.*
    """,
    file_input,
    controls_row,
)

# Interactivity: process_inputs is bound to the button, the uploaded bytes and
# the slider value, and its output is shown in a fixed-height pane.
bound_results = pn.bind(
    process_inputs,
    regenerate,
    file_input.param.value,
    epsilon_slider.param.value,
)
interactive_result = pn.panel(bound_results, height=600)

footer = pn.pane.Markdown(
    """
    <br><br>
    If the application is too slow for you, head over to the README to get this running locally.
    """
)

# Dashboard container. process_inputs toggles `main.disabled` while it runs,
# so this name must stay at module level.
main = pn.WidgetBox(input_widgets, interactive_result, footer)

title = "Adversarial Sample Generation"
template = pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(75%, 698px)",
    header_background="#101820",
)
template.servable(title=title)
# Functions from original demo | |
# ICON_URLS = { | |
# "brand-github": "https://github.com/holoviz/panel", | |
# "brand-twitter": "https://twitter.com/Panel_Org", | |
# "brand-linkedin": "https://www.linkedin.com/company/panel-org", | |
# "message-circle": "https://discourse.holoviz.org/", | |
# "brand-discord": "https://discord.gg/AXRHnJU6sP", | |
# } | |
# async def random_url(_): | |
# pet = random.choice(["cat", "dog"]) | |
# api_url = f"https://api.the{pet}api.com/v1/images/search" | |
# async with aiohttp.ClientSession() as session: | |
# async with session.get(api_url) as resp: | |
# return (await resp.json())[0]["url"] | |
# @pn.cache | |
# def load_processor_model( | |
# processor_name: str, model_name: str | |
# ) -> Tuple[CLIPProcessor, CLIPModel]: | |
# processor = CLIPProcessor.from_pretrained(processor_name) | |
# model = CLIPModel.from_pretrained(model_name) | |
# return processor, model | |
# async def open_image_url(image_url: str) -> Image: | |
# async with aiohttp.ClientSession() as session: | |
# async with session.get(image_url) as resp: | |
# return Image.open(io.BytesIO(await resp.read())) | |
# def get_similarity_scores(class_items: List[str], image: Image) -> List[float]: | |
# processor, model = load_processor_model( | |
# "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32" | |
# ) | |
# inputs = processor( | |
# text=class_items, | |
# images=[image], | |
# return_tensors="pt", # pytorch tensors | |
# ) | |
# print(inputs) | |
# outputs = model(**inputs) | |
# logits_per_image = outputs.logits_per_image | |
# class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy() | |
# return class_likelihoods[0] | |