sayakpaul HF staff committed on
Commit
2a9ec74
1 Parent(s): 08c975a

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +89 -0
utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
+ from tensorflow import keras
8
+
9
# Side length, in pixels, of the square crop produced by `crop_layer`; also the
# default `size` for `preprocess_image`.
RESOLUTION = 224

# Keeps the central RESOLUTION x RESOLUTION window of an image batch.
crop_layer = keras.layers.CenterCrop(height=RESOLUTION, width=RESOLUTION)

# Per-channel standardization. The statistics are written on a 0-1 scale and
# multiplied by 255 here, which suggests inputs are expected on a 0-255 scale
# (the numbers coincide with the widely used ImageNet RGB statistics).
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[(channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)],
)

# Affine rescaling x -> x / 127.5 - 1 (maps 0 to -1 and 255 to 1).
rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)
17
+
18
+
19
def preprocess_image(orig_image: Image.Image, model_type: str, size: int = RESOLUTION):
    """Preprocess an image for a ViT-family model.

    The image is converted to an array, given a leading batch axis, resized
    with bicubic interpolation, center-cropped with the module-level
    ``crop_layer`` and then either rescaled or normalized depending on
    ``model_type``.

    Args:
        orig_image: Image to preprocess. It is converted with ``np.array`` and
            returned untouched as the first element of the result.
        model_type: ``"original_vit"`` rescales pixel values with
            ``rescale_layer`` (to [-1, 1] for 0-255 inputs); any other value
            applies ``norm_layer`` after cropping.
        size: Resolution used to derive the resize side,
            ``int(256 / 224 * size)``. NOTE: the crop is performed by the
            module-level ``crop_layer``, which is built with ``RESOLUTION``,
            so ``size`` only influences the resize step.

    Returns:
        A tuple ``(orig_image, preprocessed)`` where ``preprocessed`` is a
        NumPy array holding the preprocessed image with a leading batch axis.
    """
    # Turn the image into an array and add the batch dimension.
    image = tf.expand_dims(np.array(orig_image), 0)

    # The original ViT checkpoints expect inputs rescaled to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize (bicubic) to a slightly larger square before center-cropping.
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")

    preprocessed_image = crop_layer(image)

    # Every other model type (DeiT, DINO, ...) is mean/variance normalized.
    # The result must be bound to the returned variable; previously it was
    # assigned to an unused name, so the normalization was silently dropped.
    if model_type != "original_vit":
        preprocessed_image = norm_layer(preprocessed_image)

    return orig_image, preprocessed_image.numpy()
41
+
42
+
43
def attention_rollout_map(
    image: Image.Image, attention_score_dict: Dict[str, np.ndarray], model_type: str
):
    """Compute an attention-rollout visualization for ``image``.

    Reference:
        https://arxiv.org/abs/2005.00928

    Code copied and modified from here:
        https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb

    Args:
        image: Image to overlay the attention mask on. Its ``size`` attribute
            is passed to ``cv2.resize`` as the target size, so it is expected
            to behave like a PIL image (``size == (width, height)``).
        attention_score_dict: Per-block attention scores, iterated in
            insertion order. Each value presumably has shape
            ``(1, num_heads, num_tokens, num_tokens)`` -- axis 1 of the
            stacked tensor is squeezed (batch of one) and the following axis
            is averaged as the head axis; confirm against the caller.
        model_type: Model identifier. If it contains ``"distilled"`` two
            leading special tokens are skipped, otherwise one.

    Returns:
        A ``uint8`` array: ``image`` multiplied element-wise by the attention
        mask, which is resized to the image size and scaled so that its
        maximum is 1.
    """
    num_cls_tokens = 2 if "distilled" in model_type else 1

    # Stack the attention matrices of the individual transformer blocks and
    # drop the singleton batch axis.
    attn_mat = tf.stack(list(attention_score_dict.values()))
    attn_mat = tf.squeeze(attn_mat, axis=1)

    # Average the attention weights across all heads.
    attn_mat = tf.reduce_mean(attn_mat, axis=1)

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize every row so it sums to one.
    aug_attn_mat = attn_mat + tf.eye(attn_mat.shape[1])
    aug_attn_mat = aug_attn_mat / tf.reduce_sum(
        aug_attn_mat, axis=-1, keepdims=True
    )
    aug_attn_mat = aug_attn_mat.numpy()

    # Roll the attention out through the blocks: multiply each block's matrix
    # onto the running product. Only the final product is needed, so the
    # intermediate results are not stored. Accumulate in float64.
    rollout = aug_attn_mat[0].astype(np.float64)
    for block_attn in aug_attn_mat[1:]:
        rollout = np.matmul(block_attn, rollout)

    # Attention from the first (class) token to the patch tokens, laid out on
    # the square patch grid. The grid side is derived from the number of
    # patch tokens, i.e. excluding the special tokens.
    num_patches = rollout.shape[-1] - num_cls_tokens
    grid_size = int(np.sqrt(num_patches))
    mask = rollout[0, num_cls_tokens:].reshape(grid_size, grid_size)

    # Scale to a maximum of 1, upsample to the image size and add a trailing
    # axis so the mask broadcasts over the image's channel axis.
    mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
    result = (mask * np.asarray(image)).astype("uint8")
    return result