broadwell committed
Commit c3473c5
Parent: 26fb07e

Upload 2 files

Files changed (2)
  1. app.py +312 -240
  2. requirements.txt +1 -1
app.py CHANGED
@@ -27,7 +27,7 @@ from CLIP_Explainability.vit_cam import (
27
 
28
  from pytorch_grad_cam.grad_cam import GradCAM
29
 
30
- RUN_LITE = False # Load vision model for CAM viz explainability for M-CLIP only
31
 
32
  MAX_IMG_WIDTH = 500
33
  MAX_IMG_HEIGHT = 800
@@ -58,7 +58,10 @@ def encode_search_query(search_query, model_type):
58
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
59
  elif model_type == "J-CLIP (日本語 ViT)":
60
  t_text = st.session_state.ja_tokenizer(
61
- search_query, padding=True, return_tensors="pt"
62
  )
63
  text_encoded = st.session_state.ja_model.get_text_features(**t_text)
64
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
@@ -67,7 +70,7 @@ def encode_search_query(search_query, model_type):
67
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
68
 
69
  # Retrieve the feature vector
70
- return text_encoded
71
 
72
 
73
  def clip_search(search_query):
@@ -153,7 +156,9 @@ def load_image_features():
153
  def init():
154
  st.session_state.current_page = 1
155
 
156
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
157
  st.session_state.device = device
158
 
159
  # Load the open CLIP models
@@ -168,7 +173,7 @@ def init():
168
 
169
  st.session_state.ml_model = (
170
  pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
171
- )
172
  st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
173
 
174
  ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
@@ -193,7 +198,7 @@ def init():
193
 
194
  st.session_state.rn_model = legacy_multilingual_clip.load_model(
195
  "M-BERT-Base-69"
196
- )
197
  st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
198
  "bert-base-multilingual-cased"
199
  )
@@ -210,7 +215,6 @@ def init():
210
  st.session_state.vision_mode = "tiled"
211
  st.session_state.search_image_ids = []
212
  st.session_state.search_image_scores = {}
213
- st.session_state.activations_image = None
214
  st.session_state.text_table_df = None
215
 
216
  with st.spinner("Loading models and data, please wait..."):
@@ -221,233 +225,271 @@ if "images_info" not in st.session_state:
221
  init()
222
 
223
 
224
- def visualize_gradcam(viz_image_id):
225
- if "search_field_value" not in st.session_state:
226
- return
227
 
228
- header_cols = st.columns([80, 20], vertical_alignment="bottom")
229
- with header_cols[0]:
230
- st.title("Image + query details")
231
- with header_cols[1]:
232
- if st.button("Close"):
233
- st.rerun()
234
 
235
- st.markdown(
236
- f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
237
- )
238
 
239
- with st.spinner("Calculating..."):
240
- # info_text = st.text("Calculating activation regions...")
241
 
242
- image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
243
- image_response = requests.get(image_url)
244
- image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])
245
- image = image.convert("RGB")
 
246
 
247
- img_dim = 224
248
- if st.session_state.active_model == "M-CLIP (multilingual ViT)":
249
- img_dim = 240
250
- elif st.session_state.active_model == "Legacy (multilingual ResNet)":
251
- img_dim = 288
252
-
253
- orig_img_dims = image.size
254
-
255
- ##### If the features are based on tiled image slices
256
- tile_behavior = None
257
-
258
- if st.session_state.vision_mode == "tiled":
259
- scaled_dims = [img_dim, img_dim]
260
-
261
- if orig_img_dims[0] > orig_img_dims[1]:
262
- scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
263
- if scale_ratio > 1:
264
- scaled_dims = [scale_ratio * img_dim, img_dim]
265
- tile_behavior = "width"
266
- elif orig_img_dims[0] < orig_img_dims[1]:
267
- scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
268
- if scale_ratio > 1:
269
- scaled_dims = [img_dim, scale_ratio * img_dim]
270
- tile_behavior = "height"
271
-
272
- resized_image = image.resize(scaled_dims, Image.LANCZOS)
273
-
274
- if tile_behavior == "width":
275
- image_tiles = []
276
- for x in range(0, scale_ratio):
277
- box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
278
- image_tiles.append(resized_image.crop(box))
279
-
280
- elif tile_behavior == "height":
281
- image_tiles = []
282
- for y in range(0, scale_ratio):
283
- box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
284
- image_tiles.append(resized_image.crop(box))
285
-
286
- else:
287
- image_tiles = [resized_image]
288
-
289
- elif st.session_state.vision_mode == "stretched":
290
- image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
291
-
292
- else: # vision_mode == "cropped"
293
- if orig_img_dims[0] > orig_img_dims[1]:
294
- scale_factor = orig_img_dims[0] / orig_img_dims[1]
295
- resized_img_dims = (round(scale_factor * img_dim), img_dim)
296
- resized_img = image.resize(resized_img_dims)
297
- elif orig_img_dims[0] < orig_img_dims[1]:
298
- scale_factor = orig_img_dims[1] / orig_img_dims[0]
299
- resized_img_dims = (img_dim, round(scale_factor * img_dim))
300
- else:
301
- resized_img_dims = (img_dim, img_dim)
302
303
  resized_img = image.resize(resized_img_dims)
304
 
305
- left = round((resized_img_dims[0] - img_dim) / 2)
306
- top = round((resized_img_dims[1] - img_dim) / 2)
307
- x_right = round(resized_img_dims[0] - img_dim) - left
308
- x_bottom = round(resized_img_dims[1] - img_dim) - top
309
- right = resized_img_dims[0] - x_right
310
- bottom = resized_img_dims[1] - x_bottom
311
 
312
- # Crop the center of the image
313
- image_tiles = [resized_img.crop((left, top, right, bottom))]
 
314
 
315
- image_visualizations = []
316
 
317
- if st.session_state.active_model == "M-CLIP (multilingual ViT)":
318
- # Sometimes used for token importance viz
319
- tokenized_text = st.session_state.ml_tokenizer.tokenize(
320
- st.session_state.search_field_value
 
321
  )
322
 
323
- text_features = st.session_state.ml_model.forward(
324
- st.session_state.search_field_value, st.session_state.ml_tokenizer
325
  )
326
 
327
- image_model = st.session_state.ml_image_model
328
-
329
- for altered_image in image_tiles:
330
- p_image = (
331
- st.session_state.ml_image_preprocess(altered_image)
332
- .unsqueeze(0)
333
- .to(st.session_state.device)
334
- )
335
-
336
- vis_t = interpret_vit_overlapped(
337
- p_image.type(st.session_state.ml_image_model.dtype),
338
- text_features,
339
- image_model.visual,
340
- st.session_state.device,
341
- img_dim=img_dim,
342
- )
343
-
344
- image_visualizations.append(vis_t)
345
-
346
- elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
347
- # Sometimes used for token importance viz
348
- tokenized_text = st.session_state.ja_tokenizer.tokenize(
349
- st.session_state.search_field_value
350
  )
351
 
352
- t_text = st.session_state.ja_tokenizer(
353
- st.session_state.search_field_value, return_tensors="pt"
354
  )
355
- text_features = st.session_state.ja_model.get_text_features(**t_text)
356
-
357
- image_model = st.session_state.ja_image_model
358
-
359
- for altered_image in image_tiles:
360
- p_image = (
361
- st.session_state.ja_image_preprocess(altered_image)
362
- .unsqueeze(0)
363
- .to(st.session_state.device)
364
- )
365
-
366
- vis_t = interpret_vit_overlapped(
367
- p_image.type(st.session_state.ja_image_model.dtype),
368
- text_features,
369
- image_model.visual,
370
- st.session_state.device,
371
- img_dim=img_dim,
372
- )
373
-
374
- image_visualizations.append(vis_t)
375
-
376
- else: # st.session_state.active_model == Legacy
377
- # Sometimes used for token importance viz
378
- tokenized_text = st.session_state.rn_tokenizer.tokenize(
379
- st.session_state.search_field_value
380
  )
381
 
382
- text_features = st.session_state.rn_model(
383
- st.session_state.search_field_value
384
  )
385
 
386
- image_model = st.session_state.rn_image_model
 
387
 
388
- for altered_image in image_tiles:
389
- p_image = (
390
- st.session_state.rn_image_preprocess(altered_image)
391
- .unsqueeze(0)
392
- .to(st.session_state.device)
393
- )
394
 
395
- vis_t = interpret_rn_overlapped(
396
- p_image.type(st.session_state.rn_image_model.dtype),
397
- text_features,
398
- image_model.visual,
399
- GradCAM,
400
- st.session_state.device,
401
- img_dim=img_dim,
402
- )
403
 
404
- image_visualizations.append(vis_t)
405
 
406
- transform = ToPILImage()
407
 
408
- vis_images = [transform(vis_t) for vis_t in image_visualizations]
409
 
410
- if st.session_state.vision_mode == "cropped":
411
- resized_img.paste(vis_images[0], (left, top))
412
- vis_images = [resized_img]
413
 
414
- if orig_img_dims[0] > orig_img_dims[1]:
415
- scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
416
- scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
417
- else:
418
- scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
419
- scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
420
 
421
- if tile_behavior == "width":
422
- vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
423
- for x, v_img in enumerate(vis_images):
424
- vis_image.paste(v_img, (x * img_dim, 0))
425
- st.session_state.activations_image = vis_image.resize(scaled_dims)
426
 
427
- elif tile_behavior == "height":
428
- vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
429
- for y, v_img in enumerate(vis_images):
430
- vis_image.paste(v_img, (0, y * img_dim))
431
- st.session_state.activations_image = vis_image.resize(scaled_dims)
432
 
433
- else:
434
- st.session_state.activations_image = vis_images[0].resize(scaled_dims)
435
 
436
- image_io = BytesIO()
437
- st.session_state.activations_image.save(image_io, "PNG")
438
- dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
439
- "ascii"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  )
441
 
442
- st.html(
443
- f"""<div style="display: flex; flex-direction: column; align-items: center;">
444
- <img src="{dataurl}" />
445
- </div>"""
 
446
  )
447
 
448
- tokenized_text = [tok.replace("▁", "") for tok in tokenized_text if tok != "▁"]
449
  tokenized_text = [
450
- tok for tok in tokenized_text if tok not in ["s", "ed", "a", "the", "an", "ing"]
451
  ]
452
 
453
  if (
@@ -457,8 +499,7 @@ def visualize_gradcam(viz_image_id):
457
  "Calculate text importance (may take some time)",
458
  )
459
  ):
460
- search_tokens = []
461
- token_scores = []
462
 
463
  progress_text = f"Processing {len(tokenized_text)} text tokens"
464
  progress_bar = st.progress(0.0, text=progress_text)
@@ -466,34 +507,37 @@ def visualize_gradcam(viz_image_id):
466
  for t, tok in enumerate(tokenized_text):
467
  token = tok
468
 
469
- if st.session_state.active_model == "Legacy (multilingual ResNet)":
470
- word_rel = rn_perword_relevance(
471
- p_image,
472
- st.session_state.search_field_value,
473
- image_model,
474
- tokenize,
475
- GradCAM,
476
- st.session_state.device,
477
- token,
478
- data_only=True,
479
- img_dim=img_dim,
480
- )
481
- else:
482
- word_rel = vit_perword_relevance(
483
- p_image,
484
- st.session_state.search_field_value,
485
- image_model,
486
- tokenize,
487
- st.session_state.device,
488
- token,
489
- data_only=True,
490
- img_dim=img_dim,
491
- )
492
- avg_score = np.mean(word_rel)
493
- if avg_score == 0 or np.isnan(avg_score):
494
- continue
495
- search_tokens.append(token)
496
- token_scores.append(1 / avg_score)
497
 
498
  progress_bar.progress(
499
  (t + 1) / len(tokenized_text),
@@ -501,24 +545,48 @@ def visualize_gradcam(viz_image_id):
501
  )
502
  progress_bar.empty()
503
 
504
- normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)
505
 
506
  token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
507
  st.session_state.text_table_df = pd.DataFrame(
508
- {"token": search_tokens, "importance": token_scores}
509
  )
510
 
511
  st.markdown("**Importance of each text token to relevance score**")
512
  st.table(st.session_state.text_table_df)
513
 
514
 
515
- def format_vision_mode(mode_stub):
516
- return mode_stub.capitalize()
 
517
 
518
 
519
- @st.dialog(" ", width="large")
520
- def image_modal(vis_image_id):
521
- visualize_gradcam(vis_image_id)
522
 
523
 
524
  st.title("Explore Japanese visual aesthetics with CLIP models")
@@ -637,7 +705,7 @@ else:
637
  use_container_width=True,
638
  )
639
 
640
- controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")
641
  with controls[0]:
642
  im_per_pg = st.columns([30, 70], vertical_alignment="center")
643
  with im_per_pg[0]:
@@ -647,8 +715,6 @@ with controls[0]:
647
  "Images/page:", range(10, 50, 10), label_visibility="collapsed"
648
  )
649
  with controls[1]:
650
- st.empty()
651
- with controls[2]:
652
  im_per_row = st.columns([30, 70], vertical_alignment="center")
653
  with im_per_row[0]:
654
  st.markdown("**Images/row:**")
@@ -657,9 +723,7 @@ with controls[2]:
657
  "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
658
  )
659
  num_batches = ceil(len(st.session_state.image_ids) / batch_size)
660
- with controls[3]:
661
- st.empty()
662
- with controls[4]:
663
  pager = st.columns([40, 60], vertical_alignment="center")
664
  with pager[0]:
665
  st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
@@ -672,6 +736,14 @@ with controls[4]:
672
  label_visibility="collapsed",
673
  key="current_page",
674
  )
675
 
676
 
677
  if len(st.session_state.search_image_ids) == 0:
@@ -708,7 +780,7 @@ for image_id in batch:
708
  if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
709
  st.button(
710
  "Explain this",
711
- on_click=image_modal,
712
  args=[image_id],
713
  use_container_width=True,
714
  key=image_id,
 
27
 
28
  from pytorch_grad_cam.grad_cam import GradCAM
29
 
30
+ RUN_LITE = True # Load vision model for CAM viz explainability for M-CLIP only
31
 
32
  MAX_IMG_WIDTH = 500
33
  MAX_IMG_HEIGHT = 800
 
58
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
59
  elif model_type == "J-CLIP (日本語 ViT)":
60
  t_text = st.session_state.ja_tokenizer(
61
+ search_query,
62
+ padding=True,
63
+ return_tensors="pt",
64
+ device=st.session_state.device,
65
  )
66
  text_encoded = st.session_state.ja_model.get_text_features(**t_text)
67
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
 
70
  text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
71
 
72
  # Retrieve the feature vector
73
+ return text_encoded.to(st.session_state.device)
74
 
75
 
76
  def clip_search(search_query):
 
156
  def init():
157
  st.session_state.current_page = 1
158
 
159
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
160
+ device = "cpu"
161
+
162
  st.session_state.device = device
163
 
164
  # Load the open CLIP models
 
173
 
174
  st.session_state.ml_model = (
175
  pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
176
+ ).to(device)
177
  st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
178
 
179
  ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
 
198
 
199
  st.session_state.rn_model = legacy_multilingual_clip.load_model(
200
  "M-BERT-Base-69"
201
+ ).to(device)
202
  st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
203
  "bert-base-multilingual-cased"
204
  )
 
215
  st.session_state.vision_mode = "tiled"
216
  st.session_state.search_image_ids = []
217
  st.session_state.search_image_scores = {}
 
218
  st.session_state.text_table_df = None
219
 
220
  with st.spinner("Loading models and data, please wait..."):
 
225
  init()
226
 
227
 
228
+ def get_overlay_vis(image, img_dim, image_model):
229
+ orig_img_dims = image.size
 
230
 
231
+ ##### If the features are based on tiled image slices
232
+ tile_behavior = None
233
 
234
+ if st.session_state.vision_mode == "tiled":
235
+ scaled_dims = [img_dim, img_dim]
 
236
 
237
+ if orig_img_dims[0] > orig_img_dims[1]:
238
+ scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
239
+ if scale_ratio > 1:
240
+ scaled_dims = [scale_ratio * img_dim, img_dim]
241
+ tile_behavior = "width"
242
+ elif orig_img_dims[0] < orig_img_dims[1]:
243
+ scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
244
+ if scale_ratio > 1:
245
+ scaled_dims = [img_dim, scale_ratio * img_dim]
246
+ tile_behavior = "height"
247
+
248
+ resized_image = image.resize(scaled_dims, Image.LANCZOS)
249
 
250
+ if tile_behavior == "width":
251
+ image_tiles = []
252
+ for x in range(0, scale_ratio):
253
+ box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
254
+ image_tiles.append(resized_image.crop(box))
255
 
256
+ elif tile_behavior == "height":
257
+ image_tiles = []
258
+ for y in range(0, scale_ratio):
259
+ box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
260
+ image_tiles.append(resized_image.crop(box))
261
+
262
+ else:
263
+ image_tiles = [resized_image]
264
+
265
+ elif st.session_state.vision_mode == "stretched":
266
+ image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
267
 
268
+ else: # vision_mode == "cropped"
269
+ if orig_img_dims[0] > orig_img_dims[1]:
270
+ scale_factor = orig_img_dims[0] / orig_img_dims[1]
271
+ resized_img_dims = (round(scale_factor * img_dim), img_dim)
272
  resized_img = image.resize(resized_img_dims)
273
+ elif orig_img_dims[0] < orig_img_dims[1]:
274
+ scale_factor = orig_img_dims[1] / orig_img_dims[0]
275
+ resized_img_dims = (img_dim, round(scale_factor * img_dim))
276
+ else:
277
+ resized_img_dims = (img_dim, img_dim)
278
+
279
+ resized_img = image.resize(resized_img_dims)
280
+
281
+ left = round((resized_img_dims[0] - img_dim) / 2)
282
+ top = round((resized_img_dims[1] - img_dim) / 2)
283
+ x_right = round(resized_img_dims[0] - img_dim) - left
284
+ x_bottom = round(resized_img_dims[1] - img_dim) - top
285
+ right = resized_img_dims[0] - x_right
286
+ bottom = resized_img_dims[1] - x_bottom
287
 
288
+ # Crop the center of the image
289
+ image_tiles = [resized_img.crop((left, top, right, bottom))]
 
 
 
 
290
 
291
+ image_visualizations = []
292
+ image_features = []
293
+ image_similarities = []
294
 
295
+ if st.session_state.active_model == "M-CLIP (multilingual ViT)":
296
+ text_features = st.session_state.ml_model.forward(
297
+ st.session_state.search_field_value, st.session_state.ml_tokenizer
298
+ )
299
+
300
+ if st.session_state.device == "cpu":
301
+ text_features = text_features.float().to(st.session_state.device)
302
+ else:
303
+ text_features = text_features.to(st.session_state.device)
304
 
305
+ for altered_image in image_tiles:
306
+ p_image = (
307
+ st.session_state.ml_image_preprocess(altered_image)
308
+ .unsqueeze(0)
309
+ .to(st.session_state.device)
310
  )
311
 
312
+ vis_t, img_feats, similarity = interpret_vit_overlapped(
313
+ p_image.type(image_model.dtype),
314
+ text_features.type(image_model.dtype),
315
+ image_model.visual,
316
+ st.session_state.device,
317
+ img_dim=img_dim,
318
  )
319
 
320
+ image_visualizations.append(vis_t)
321
+ image_features.append(img_feats)
322
+ image_similarities.append(similarity.item())
323
+
324
+ elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
325
+ t_text = st.session_state.ja_tokenizer(
326
+ st.session_state.search_field_value,
327
+ return_tensors="pt",
328
+ device=st.session_state.device,
329
+ )
330
+
331
+ text_features = st.session_state.ja_model.get_text_features(**t_text)
332
+
333
+ if st.session_state.device == "cpu":
334
+ text_features = text_features.float().to(st.session_state.device)
335
+ else:
336
+ text_features = text_features.to(st.session_state.device)
337
+
338
+ for altered_image in image_tiles:
339
+ p_image = (
340
+ st.session_state.ja_image_preprocess(altered_image)
341
+ .unsqueeze(0)
342
+ .to(st.session_state.device)
343
  )
344
 
345
+ vis_t, img_feats, similarity = interpret_vit_overlapped(
346
+ p_image.type(image_model.dtype),
347
+ text_features.type(image_model.dtype),
348
+ image_model.visual,
349
+ st.session_state.device,
350
+ img_dim=img_dim,
351
  )
352
+
353
+ image_visualizations.append(vis_t)
354
+ image_features.append(img_feats)
355
+ image_similarities.append(similarity.item())
356
+
357
+ else: # st.session_state.active_model == Legacy
358
+ text_features = st.session_state.rn_model(st.session_state.search_field_value)
359
+
360
+ if st.session_state.device == "cpu":
361
+ text_features = text_features.float().to(st.session_state.device)
362
+ else:
363
+ text_features = text_features.to(st.session_state.device)
364
+
365
+ for altered_image in image_tiles:
366
+ p_image = (
367
+ st.session_state.rn_image_preprocess(altered_image)
368
+ .unsqueeze(0)
369
+ .to(st.session_state.device)
370
  )
371
 
372
+ vis_t = interpret_rn_overlapped(
373
+ p_image.type(image_model.dtype),
374
+ text_features.type(image_model.dtype),
375
+ image_model.visual,
376
+ GradCAM,
377
+ st.session_state.device,
378
+ img_dim=img_dim,
379
  )
380
 
381
+ text_features_norm = text_features.norm(dim=-1, keepdim=True)
382
+ text_features_new = text_features / text_features_norm
383
 
384
+ image_feats = image_model.encode_image(p_image.type(image_model.dtype))
385
+ image_feats_norm = image_feats.norm(dim=-1, keepdim=True)
386
+ image_feats_new = image_feats / image_feats_norm
387
 
388
+ similarity = image_feats_new[0].dot(text_features_new[0])
389
 
390
+ image_visualizations.append(vis_t)
391
+ image_features.append(p_image)
392
+ image_similarities.append(similarity.item())
393
 
394
+ transform = ToPILImage()
395
 
396
+ vis_images = [transform(vis_t) for vis_t in image_visualizations]
397
 
398
+ if st.session_state.vision_mode == "cropped":
399
+ resized_img.paste(vis_images[0], (left, top))
400
+ vis_images = [resized_img]
401
 
402
+ if orig_img_dims[0] > orig_img_dims[1]:
403
+ scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
404
+ scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
405
+ else:
406
+ scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
407
+ scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
408
 
409
+ if tile_behavior == "width":
410
+ vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
411
+ for x, v_img in enumerate(vis_images):
412
+ vis_image.paste(v_img, (x * img_dim, 0))
413
+ activations_image = vis_image.resize(scaled_dims)
414
 
415
+ elif tile_behavior == "height":
416
+ vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
417
+ for y, v_img in enumerate(vis_images):
418
+ vis_image.paste(v_img, (0, y * img_dim))
419
+ activations_image = vis_image.resize(scaled_dims)
420
 
421
+ else:
422
+ activations_image = vis_images[0].resize(scaled_dims)
423
 
424
+ return activations_image, image_features, np.mean(image_similarities)
425
+
426
+
427
+ def visualize_gradcam(image):
428
+ if "search_field_value" not in st.session_state:
429
+ return
430
+
431
+ header_cols = st.columns([80, 20], vertical_alignment="bottom")
432
+ with header_cols[0]:
433
+ st.title("Image + query details")
434
+ with header_cols[1]:
435
+ if st.button("Close"):
436
+ st.rerun()
437
+
438
+ if st.session_state.active_model == "M-CLIP (multilingual ViT)":
439
+ img_dim = 240
440
+ image_model = st.session_state.ml_image_model
441
+ # Sometimes used for token importance viz
442
+ tokenized_text = st.session_state.ml_tokenizer.tokenize(
443
+ st.session_state.search_field_value
444
+ )
445
+ elif st.session_state.active_model == "Legacy (multilingual ResNet)":
446
+ img_dim = 288
447
+ image_model = st.session_state.rn_image_model
448
+ # Sometimes used for token importance viz
449
+ tokenized_text = st.session_state.rn_tokenizer.tokenize(
450
+ st.session_state.search_field_value
451
+ )
452
+ else: # J-CLIP
453
+ img_dim = 224
454
+ image_model = st.session_state.ja_image_model
455
+ # Sometimes used for token importance viz
456
+ tokenized_text = st.session_state.ja_tokenizer.tokenize(
457
+ st.session_state.search_field_value
458
  )
459
 
460
+ with st.spinner("Calculating..."):
461
+ # info_text = st.text("Calculating activation regions...")
462
+
463
+ activations_image, image_features, similarity_score = get_overlay_vis(
464
+ image, img_dim, image_model
465
  )
466
 
467
+ st.markdown(
468
+ f"**Query text:** {st.session_state.search_field_value} | **Approx. image relevance:** {round(similarity_score.item(), 3)}"
469
+ )
470
+
471
+ st.image(activations_image)
472
+
473
+ # image_io = BytesIO()
474
+ # activations_image.save(image_io, "PNG")
475
+ # dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
476
+ # "ascii"
477
+ # )
478
+
479
+ # st.html(
480
+ # f"""<div style="display: flex; flex-direction: column; align-items: center;">
481
+ # <img src="{dataurl}" />
482
+ # </div>"""
483
+ # )
484
+
485
+ tokenized_text = [
486
+ tok.replace("▁", "").replace("#", "") for tok in tokenized_text if tok != "▁"
487
+ ]
488
  tokenized_text = [
489
+ tok
490
+ for tok in tokenized_text
491
+ if tok
492
+ not in ["s", "ed", "a", "the", "an", "ing", "て", "に", "の", "は", "と", "た"]
493
  ]
494
 
495
  if (
 
499
  "Calculate text importance (may take some time)",
500
  )
501
  ):
502
+ scores_per_token = {}
 
503
 
504
  progress_text = f"Processing {len(tokenized_text)} text tokens"
505
  progress_bar = st.progress(0.0, text=progress_text)
 
507
  for t, tok in enumerate(tokenized_text):
508
  token = tok
509
 
510
+ for img_feats in image_features:
511
+ if st.session_state.active_model == "Legacy (multilingual ResNet)":
512
+ word_rel = rn_perword_relevance(
513
+ img_feats,
514
+ st.session_state.search_field_value,
515
+ image_model,
516
+ tokenize,
517
+ GradCAM,
518
+ st.session_state.device,
519
+ token,
520
+ data_only=True,
521
+ img_dim=img_dim,
522
+ )
523
+ else:
524
+ word_rel = vit_perword_relevance(
525
+ img_feats,
526
+ st.session_state.search_field_value,
527
+ image_model,
528
+ tokenize,
529
+ st.session_state.device,
530
+ token,
531
+ img_dim=img_dim,
532
+ )
533
+ avg_score = np.mean(word_rel)
534
+ if avg_score == 0 or np.isnan(avg_score):
535
+ continue
536
+
537
+ if token not in scores_per_token:
538
+ scores_per_token[token] = [1 / avg_score]
539
+ else:
540
+ scores_per_token[token].append(1 / avg_score)
541
 
542
  progress_bar.progress(
543
  (t + 1) / len(tokenized_text),
 
545
  )
546
  progress_bar.empty()
547
 
548
+ avg_scores_per_token = [
549
+ np.mean(scores_per_token[tok]) for tok in list(scores_per_token.keys())
550
+ ]
551
+
552
+ normed_scores = torch.softmax(torch.tensor(avg_scores_per_token), dim=0)
553
 
554
  token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
555
  st.session_state.text_table_df = pd.DataFrame(
556
+ {"token": list(scores_per_token.keys()), "importance": token_scores}
557
  )
558
 
559
  st.markdown("**Importance of each text token to relevance score**")
560
  st.table(st.session_state.text_table_df)
561
 
562
 
563
+ @st.dialog(" ", width="large")
564
+ def image_modal(image):
565
+ visualize_gradcam(image)
566
 
567
 
568
+ def vis_known_image(vis_image_id):
569
+ image_url = st.session_state.images_info.loc[vis_image_id]["image_url"]
570
+ image_response = requests.get(image_url)
571
+ image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF", "PNG"])
572
+ image = image.convert("RGB")
573
+
574
+ image_modal(image)
575
+
576
+
577
+ def vis_uploaded_image():
578
+ uploaded_file = st.session_state.uploaded_image
579
+ if uploaded_file is not None:
580
+ # To read file as bytes:
581
+ bytes_data = uploaded_file.getvalue()
582
+ image = Image.open(BytesIO(bytes_data), formats=["JPEG", "GIF", "PNG"])
583
+ image = image.convert("RGB")
584
+
585
+ image_modal(image)
586
+
587
+
588
+ def format_vision_mode(mode_stub):
589
+ return mode_stub.capitalize()
590
 
591
 
592
  st.title("Explore Japanese visual aesthetics with CLIP models")
 
705
  use_container_width=True,
706
  )
707
 
708
+ controls = st.columns([25, 25, 20, 35], gap="large", vertical_alignment="center")
709
  with controls[0]:
710
  im_per_pg = st.columns([30, 70], vertical_alignment="center")
711
  with im_per_pg[0]:
 
715
  "Images/page:", range(10, 50, 10), label_visibility="collapsed"
716
  )
717
  with controls[1]:
718
  im_per_row = st.columns([30, 70], vertical_alignment="center")
719
  with im_per_row[0]:
720
  st.markdown("**Images/row:**")
 
723
  "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
724
  )
725
  num_batches = ceil(len(st.session_state.image_ids) / batch_size)
726
+ with controls[2]:
727
  pager = st.columns([40, 60], vertical_alignment="center")
728
  with pager[0]:
729
  st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
 
736
  label_visibility="collapsed",
737
  key="current_page",
738
  )
739
+ with controls[3]:
740
+ st.file_uploader(
741
+ "Upload an image",
742
+ type=["jpg", "jpeg", "gif", "png"],
743
+ key="uploaded_image",
744
+ label_visibility="collapsed",
745
+ on_change=vis_uploaded_image,
746
+ )
747
 
748
 
749
  if len(st.session_state.search_image_ids) == 0:
 
780
  if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
781
  st.button(
782
  "Explain this",
783
+ on_click=vis_known_image,
784
  args=[image_id],
785
  use_container_width=True,
786
  key=image_id,
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  clip @ git+https://github.com/openai/CLIP.git
2
  ftfy==6.2.0
 
3
  multilingual_clip==1.0.10
4
  numpy==1.26
5
  opencv-python==4.10.0.84
@@ -7,7 +8,6 @@ pandas==2.1.2
7
  pillow==10.1.0
8
  requests==2.31.0
9
  sentencepiece==0.2.0
10
- streamlit
11
  torch==2.4.0
12
  torchvision==0.19.0
13
  transformers==4.35.0
 
1
  clip @ git+https://github.com/openai/CLIP.git
2
  ftfy==6.2.0
3
+ matplotlib==3.8.1
4
  multilingual_clip==1.0.10
5
  numpy==1.26
6
  opencv-python==4.10.0.84
 
8
  pillow==10.1.0
9
  requests==2.31.0
10
  sentencepiece==0.2.0
 
11
  torch==2.4.0
12
  torchvision==0.19.0
13
  transformers==4.35.0