patrickramos commited on
Commit
22d82a8
1 Parent(s): cffaaa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from transformers import ViTFeatureExtractor, ViTModel
4
  from skops import hub_utils
5
  from einops import reduce
 
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
 
@@ -57,14 +58,18 @@ def classify_and_heatmap(input_img):
57
  ax_orig_img = fig.add_subplot(gs[:, 0])
58
 
59
  # plot original image
60
- img = feature_extractor.to_pil_image(
61
  inputs['pixel_values'].squeeze(0) * torch.tensor(feature_extractor.image_std).view(-1, 1, 1) + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
62
  )
63
  ax_orig_img.imshow(img)
64
  ax_orig_img.axis('off')
65
 
66
  # plot patch contributions
67
- patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
 
 
 
 
68
  vmin = patch_contributions.min()
69
  vmax = patch_contributions.max()
70
 
 
3
  from transformers import ViTFeatureExtractor, ViTModel
4
  from skops import hub_utils
5
  from einops import reduce
6
+ from torchvision.transforms.functional import to_pil_image
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
 
 
58
  ax_orig_img = fig.add_subplot(gs[:, 0])
59
 
60
  # plot original image
61
+ img = to_pil_image(
62
  inputs['pixel_values'].squeeze(0) * torch.tensor(feature_extractor.image_std).view(-1, 1, 1) + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
63
  )
64
  ax_orig_img.imshow(img)
65
  ax_orig_img.axis('off')
66
 
67
  # plot patch contributions
68
+ patch_contributions = (
69
+ logistic_regression.coef_ \
70
+ @ patch_embeddings.T.numpy() \
71
+ + logistic_regression.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
72
+ ).reshape(-1, num_patches_side, num_patches_side)
73
  vmin = patch_contributions.min()
74
  vmax = patch_contributions.max()
75