root commited on
Commit
0663fac
β€’
1 Parent(s): 55150dc

add random noise

Browse files
Files changed (2) hide show
  1. README.md +14 -2
  2. app.py +99 -45
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Fgsm Project
3
  emoji: πŸ“ˆ
4
  colorFrom: gray
5
  colorTo: green
@@ -7,4 +7,16 @@ sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: FGSM Project
3
  emoji: πŸ“ˆ
4
  colorFrom: gray
5
  colorTo: green
 
7
  pinned: false
8
  ---
9
 
10
+ This repository was developed inside a [devcontainer](https://containers.dev/).
11
+
12
+ If you are after speed, you can run this application locally.
13
+
14
+ 1. Clone the repository
15
+
16
+ `git clone https://huggingface.co/spaces/niniack/fgsm-project`
17
+
18
+ 2. Open up the project inside a devcontainer. Check [this](https://code.visualstudio.com/docs/devcontainers/containers) for instructions with VS Code.
19
+
20
+ 3. Start the application
21
+
22
+ `panel serve /path/to/app.py/ --dev`
app.py CHANGED
@@ -52,6 +52,7 @@ def run_forward_backward(image: Image, epsilon):
52
  )
53
 
54
  # Grab input
 
55
  input_tensor = processor(image, return_tensors="pt")["pixel_values"]
56
  input_tensor.requires_grad_(True)
57
 
@@ -70,11 +71,18 @@ def run_forward_backward(image: Image, epsilon):
70
  # Denormalize input
71
  mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
72
  std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
73
- input_tensor_denorm = input_tensor.detach() * std + mean
 
 
 
 
 
74
 
75
  # FGSM attack
76
  adv_input_tensor_denorm = fgsm_attack(
77
- image=input_tensor_denorm, epsilon=epsilon, data_grad=input_tensor.grad.data
 
 
78
  )
79
 
80
  # Normalize adversarial input tensor back to the input range
@@ -84,7 +92,6 @@ def run_forward_backward(image: Image, epsilon):
84
  adv_output = model(adv_input_tensor)
85
  adv_output = adv_output.logits
86
 
87
-
88
  return (
89
  output,
90
  adv_output,
@@ -109,10 +116,10 @@ async def process_inputs(button_event, image_data: bytes, epsilon: float):
109
  try:
110
  # Open the image using PIL
111
  pil_img = Image.open(BytesIO(image_data))
112
-
113
  # Run forward + FGSM
114
- clean_logits, adv_logits, input_tensor, adv_input_tensor = run_forward_backward(
115
- image=pil_img, epsilon=epsilon
116
  )
117
 
118
  except Exception as e:
@@ -121,21 +128,37 @@ async def process_inputs(button_event, image_data: bytes, epsilon: float):
121
 
122
  img = pn.pane.Image(
123
  to_pil_image(input_tensor, do_rescale=True),
124
- height=350,
125
  align="center",
126
  )
127
 
128
  # Convert image for visualizing
 
129
  adv_img = pn.pane.Image(
130
- to_pil_image(adv_input_tensor, do_rescale=True),
131
- height=350,
132
  align="center",
133
  )
134
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  # Build the results column
136
  k_val = 5
137
  results = pn.Column(
138
- pn.Row("###### Uploaded", "###### Adversarial"), pn.Row(img, adv_img), f" ###### Top {k_val} class predictions",
 
 
 
139
  )
140
 
141
  # Get likelihoods
@@ -150,43 +173,43 @@ async def process_inputs(button_event, image_data: bytes, epsilon: float):
150
  # Get top k values and indices
151
  vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val)
152
  label_bars = pn.Column()
153
-
154
  for idx, val in zip(idx_topk_clean, vals_topk_clean):
155
  prob = val.item()
156
  row_label = pn.widgets.StaticText(
157
- name=f"{classes[idx]}",
158
- value=f"{prob:.2%}",
159
- align="center"
160
  )
161
  row_bar = pn.indicators.Progress(
162
  value=int(prob * 100),
163
  sizing_mode="stretch_width",
164
- bar_color="success" if prob > 0.7 else "warning", # Dynamic color based on value
 
 
165
  margin=(0, 10),
166
  design=pn.theme.Material,
167
  )
168
  label_bars.append(pn.Column(row_label, row_bar))
169
 
170
- # for likelihood_tensor in likelihoods:
171
- # # Get top
172
- # vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val)
173
- # label_bars = pn.Column()
174
- # for idx, val in zip(idx_topk_clean, vals_topk_clean):
175
- # prob = val.item()
176
- # row_label = pn.widgets.StaticText(
177
- # name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
178
- # )
179
- # row_bar = pn.indicators.Progress(
180
- # value=int(prob * 100),
181
- # sizing_mode="stretch_width",
182
- # bar_color="secondary",
183
- # margin=(0, 10),
184
- # design=pn.theme.Material,
185
- # )
186
- # label_bars.append(pn.Column(row_label, row_bar))
187
 
188
  label_bars_rows.append(label_bars)
189
-
190
  results.append(label_bars_rows)
191
 
192
  yield results
@@ -194,7 +217,7 @@ async def process_inputs(button_event, image_data: bytes, epsilon: float):
194
  except Exception as e:
195
  yield f"##### Something went wrong! \n {e}"
196
  return
197
-
198
  finally:
199
  main.disabled = False
200
 
@@ -213,37 +236,68 @@ file_input = pn.widgets.FileInput(name="Upload a PNG image", accept=".png,.jpg")
213
 
214
  # Epsilon
215
  epsilon_slider = pn.widgets.FloatSlider(
216
- name=r"$$\epsilon$$", start=0, end=0.1, step=0.005, value=0.05, format='1[.]000'
 
 
 
 
 
 
 
 
217
  )
218
 
219
- # Upload button widget
220
- upload_image = pn.widgets.Button(name="Upload image", align="end")
221
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  ############################################
223
 
224
  # Organize widgets in a column
225
  input_widgets = pn.Column(
226
  """
227
- ###### Classify an image with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
 
 
228
 
229
- Please be patient with the application, it is running on a low-resource device.
230
  """,
231
  file_input,
232
- epsilon_slider,
233
  )
234
 
235
  # Add interactivity
236
  interactive_result = pn.panel(
237
  pn.bind(
238
- process_inputs, upload_image, file_input.param.value, epsilon_slider.param.value
 
 
 
239
  ),
240
  height=600,
241
  )
242
 
243
  footer = pn.pane.Markdown(
244
  """
245
- <br><br><br><br>
246
- Wondering where the class names come from? Find the full list [here](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/)
 
247
  """
248
  )
249
 
@@ -259,7 +313,7 @@ title = "Adversarial Sample Generation"
259
  pn.template.BootstrapTemplate(
260
  title=title,
261
  main=main,
262
- main_max_width="min(50%, 698px)",
263
  header_background="#101820",
264
  ).servable(title=title)
265
 
 
52
  )
53
 
54
  # Grab input
55
+ processor.crop_pct = 1
56
  input_tensor = processor(image, return_tensors="pt")["pixel_values"]
57
  input_tensor.requires_grad_(True)
58
 
 
71
  # Denormalize input
72
  mean = torch.tensor(processor.image_mean).view(1, -1, 1, 1)
73
  std = torch.tensor(processor.image_std).view(1, -1, 1, 1)
74
+ input_tensor_denorm = input_tensor.clone().detach() * std + mean
75
+
76
+ # Add noise to input
77
+ random_noise = torch.sign(torch.randn_like(input_tensor)) * 0.02
78
+ input_tensor_denorm_noised = torch.clamp(input_tensor_denorm + random_noise, 0, 1)
79
+ # input_tensor_denorm_noised = input_tensor_denorm
80
 
81
  # FGSM attack
82
  adv_input_tensor_denorm = fgsm_attack(
83
+ image=input_tensor_denorm_noised,
84
+ epsilon=epsilon,
85
+ data_grad=input_tensor.grad.data,
86
  )
87
 
88
  # Normalize adversarial input tensor back to the input range
 
92
  adv_output = model(adv_input_tensor)
93
  adv_output = adv_output.logits
94
 
 
95
  return (
96
  output,
97
  adv_output,
 
116
  try:
117
  # Open the image using PIL
118
  pil_img = Image.open(BytesIO(image_data))
119
+
120
  # Run forward + FGSM
121
+ clean_logits, adv_logits, input_tensor, adv_input_tensor = (
122
+ run_forward_backward(image=pil_img, epsilon=epsilon)
123
  )
124
 
125
  except Exception as e:
 
128
 
129
  img = pn.pane.Image(
130
  to_pil_image(input_tensor, do_rescale=True),
131
+ height=300,
132
  align="center",
133
  )
134
 
135
  # Convert image for visualizing
136
+ adv_img_pil = to_pil_image(adv_input_tensor, do_rescale=True)
137
  adv_img = pn.pane.Image(
138
+ adv_img_pil,
139
+ height=300,
140
  align="center",
141
  )
142
 
143
+ # Download image button
144
+ adv_img_bytes = io.BytesIO()
145
+ adv_img_pil.save(adv_img_bytes, format="PNG")
146
+ # download = pn.widgets.FileDownload(
147
+ # to_pil_image(adv_img_bytes, do_rescale=True),
148
+ # embed=True,
149
+ # filename="adv_img.png",
150
+ # button_type="primary",
151
+ # button_style="outline",
152
+ # width_policy="min",
153
+ # )
154
+
155
  # Build the results column
156
  k_val = 5
157
  results = pn.Column(
158
+ pn.Row("###### Uploaded", "###### Adversarial"),
159
+ pn.Row(img, adv_img),
160
+ # pn.Row(pn.Spacer(), download),
161
+ f" ###### Top {k_val} class predictions",
162
  )
163
 
164
  # Get likelihoods
 
173
  # Get top k values and indices
174
  vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val)
175
  label_bars = pn.Column()
176
+
177
  for idx, val in zip(idx_topk_clean, vals_topk_clean):
178
  prob = val.item()
179
  row_label = pn.widgets.StaticText(
180
+ name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
 
 
181
  )
182
  row_bar = pn.indicators.Progress(
183
  value=int(prob * 100),
184
  sizing_mode="stretch_width",
185
+ bar_color="success"
186
+ if prob > 0.7
187
+ else "warning", # Dynamic color based on value
188
  margin=(0, 10),
189
  design=pn.theme.Material,
190
  )
191
  label_bars.append(pn.Column(row_label, row_bar))
192
 
193
+ # for likelihood_tensor in likelihoods:
194
+ # # Get top
195
+ # vals_topk_clean, idx_topk_clean = torch.topk(likelihood_tensor, k=k_val)
196
+ # label_bars = pn.Column()
197
+ # for idx, val in zip(idx_topk_clean, vals_topk_clean):
198
+ # prob = val.item()
199
+ # row_label = pn.widgets.StaticText(
200
+ # name=f"{classes[idx]}", value=f"{prob:.2%}", align="center"
201
+ # )
202
+ # row_bar = pn.indicators.Progress(
203
+ # value=int(prob * 100),
204
+ # sizing_mode="stretch_width",
205
+ # bar_color="secondary",
206
+ # margin=(0, 10),
207
+ # design=pn.theme.Material,
208
+ # )
209
+ # label_bars.append(pn.Column(row_label, row_bar))
210
 
211
  label_bars_rows.append(label_bars)
212
+
213
  results.append(label_bars_rows)
214
 
215
  yield results
 
217
  except Exception as e:
218
  yield f"##### Something went wrong! \n {e}"
219
  return
220
+
221
  finally:
222
  main.disabled = False
223
 
 
236
 
237
  # Epsilon
238
  epsilon_slider = pn.widgets.FloatSlider(
239
+ name=r"$$\epsilon$$ parameter for FGSM",
240
+ start=0,
241
+ end=0.1,
242
+ step=0.005,
243
+ value=0.000,
244
+ format="1[.]000",
245
+ align="center",
246
+ max_width=500,
247
+ width_policy="max",
248
  )
249
 
250
+ # alpha_slider = pn.widgets.FloatSlider(
251
+ # name=r"$$\alpha$$ parameter for Gaussian noise",
252
+ # start=0,
253
+ # end=0.1,
254
+ # step=0.005,
255
+ # value=0.000,
256
+ # format="1[.]000",
257
+ # align="center",
258
+ # max_width=500,
259
+ # width_policy="max"
260
+
261
+ # )
262
+
263
+ # Regenerate button
264
+ regenerate = pn.widgets.Button(
265
+ name="Regenerate",
266
+ button_type="primary",
267
+ width_policy="min",
268
+ max_width=105,
269
+ )
270
  ############################################
271
 
272
  # Organize widgets in a column
273
  input_widgets = pn.Column(
274
  """
275
+ ###### Classify an image (png/jpeg) with a pre-trained [ResNet18](https://huggingface.co/microsoft/resnet-18) and generate an adversarial example.\n
276
+
277
+ Wondering where the class names come from? Find the list of ImageNet-1K classes [here.](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/)
278
 
279
+ *Please be patient with the application, it is running on a low-resource device.*
280
  """,
281
  file_input,
282
+ pn.Row(epsilon_slider, pn.Spacer(width_policy="min", max_width=25), regenerate),
283
  )
284
 
285
  # Add interactivity
286
  interactive_result = pn.panel(
287
  pn.bind(
288
+ process_inputs,
289
+ regenerate,
290
+ file_input.param.value,
291
+ epsilon_slider.param.value,
292
  ),
293
  height=600,
294
  )
295
 
296
  footer = pn.pane.Markdown(
297
  """
298
+ <br><br>
299
+
300
+ If the application is too slow for you, head over to the README to get this running locally.
301
  """
302
  )
303
 
 
313
  pn.template.BootstrapTemplate(
314
  title=title,
315
  main=main,
316
+ main_max_width="min(75%, 698px)",
317
  header_background="#101820",
318
  ).servable(title=title)
319