import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.decomposition import NMF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
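
# Backward hook for guided backpropagation: clamps the gradient flowing through
# ReLU layers so that only positive gradient signal reaches the input.
# (Note: newer PyTorch versions prefer register_full_backward_hook.)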
def relu_hook_function(module, grad_in, grad_out):
    if isinstance(module, nn.ReLU):
        return (F.relu(grad_in[0]),)
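
# Heavy Gaussian blur used to smooth saliency maps before display.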
def blur_sailency(input_image):
    return torchvision.transforms.functional.gaussian_blur(input_image, kernel_size=[11, 11], sigma=[5, 5])
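
# Conditional occlusion: slide a w_size x w_size zero patch over each image and
# measure how much the cosine similarity to the unoccluded other image drops.
# Returns one heatmap per image; higher values mean the occluded patch mattered more.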
def occlusion(img1, img2, model, w_size=64, stride=8, batch_size=32):
    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)

    images1 = []
    images2 = []
    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1.append(image1)
            images2.append(image2)

    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    score_map1 = []
    score_map2 = []
    assert images1.shape[0] == images2.shape[0]

    for b in range(0, images2.shape[0], batch_size):
        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])
            score_map1.append(measure(out1, out2_condition))  # or torch.mm(out2_condition, out1.t())[0]
            score_map2.append(measure(out1_condition, out2))  # or torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim=0)
    score_map2 = torch.cat(score_map2, dim=0)
    assert images2.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()
    base_score = measure(out1_condition, out2_condition)

    heatmap1 = (heatmap1 - base_score.item()) * -1  # or base_score.item() - heatmap1. The higher the drop, the better
    heatmap2 = (heatmap2 - base_score.item()) * -1  # or base_score.item() - heatmap2. The higher the drop, the better

    return heatmap1, heatmap2
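
# Context-agnostic variant of the occlusion map above: each patch is scored via
# the difference between embedding the patch alone (background zeroed) and
# embedding the image with the patch zeroed, which reduces the influence of the
# surrounding context on the score.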
def occlusion_context_agnositc(img1, img2, model, w_size=64, stride=8, batch_size=32):
    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)

    images1_occlude_mask = []
    images2_occlude_mask = []
    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1_occlude_mask.append(image1)
            images2_occlude_mask.append(image2)

    images1_occlude_mask = torch.cat(images1_occlude_mask, dim=0).to(device)
    images2_occlude_mask = torch.cat(images2_occlude_mask, dim=0).to(device)

    images1_occlude_background = []
    images2_occlude_background = []
    copy_img1 = img1.clone().detach()
    copy_img2 = img2.clone().detach()
    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = torch.zeros_like(img1)
            image2 = torch.zeros_like(img2)
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img1[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img2[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
            images1_occlude_background.append(image1)
            images2_occlude_background.append(image2)

    images1_occlude_background = torch.cat(images1_occlude_background, dim=0).to(device)
    images2_occlude_background = torch.cat(images2_occlude_background, dim=0).to(device)

    score_map1 = []
    score_map2 = []
    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0]

    for b in range(0, images1_occlude_mask.shape[0], batch_size):
        with torch.no_grad():
            out1_mask = model(images1_occlude_mask[b : b + batch_size, :])
            out2_mask = model(images2_occlude_mask[b : b + batch_size, :])
            out1_background = model(images1_occlude_background[b : b + batch_size, :])
            out2_background = model(images2_occlude_background[b : b + batch_size, :])
            out1 = out1_background - out1_mask
            out2 = out2_background - out2_mask
            score_map1.append(measure(out1, out2_condition))  # or torch.mm(out2_condition, out1.t())[0]
            score_map2.append(measure(out1_condition, out2))  # or torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim=0)
    score_map2 = torch.cat(score_map2, dim=0)
    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()
    heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min())
    heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min())

    return heatmap1, heatmap2
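
# Pairwise occlusion: apply RandomErasing to both images, measure the drop in
# similarity relative to the unerased baseline, and accumulate the softmaxed
# drops over the erased regions. The drop is attributed to whichever image's
# occluded embedding has the smaller norm.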
def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases):
    measure = nn.CosineSimilarity(dim=-1)
    out1_condition, out2_condition = model(img1), model(img2)
    baseline = measure(out1_condition, out2_condition).detach()

    # a bit sensitive to scale and ratio. erase_scale is from (scale[0] * 100)% to (scale[1] * 100)%
    random_erase = transforms.RandomErasing(p=1.0, scale=erase_scale, ratio=erase_ratio)
    image1 = img1.clone().detach()
    image2 = img2.clone().detach()
    images1 = []
    images2 = []
    for _ in range(num_erases):
        images1.append(random_erase(image1))
        images2.append(random_erase(image2))
    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    sims = []
    weights1 = []
    weights2 = []
    for b in range(0, images2.shape[0], batch_size):
        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])
            sims.append(measure(out1, out2))
            weights1.append(out1.norm(dim=-1))
            weights2.append(out2.norm(dim=-1))

    sims = torch.cat(sims, dim=0)
    weights1, weights2 = torch.cat(weights1, dim=0).cpu().numpy(), torch.cat(weights2, dim=0).cpu().numpy()
    weights = list(zip(weights1, weights2))
    sims = baseline - sims  # the higher the drop, the better
    sims = F.softmax(sims, dim=-1)
    sims = sims.cpu().numpy()
    assert sims.shape[0] == images1.shape[0] == images2.shape[0]

    A1 = np.zeros((224, 224))
    A2 = np.zeros((224, 224))
    for n in range(images1.shape[0]):
        im1_2d = images1[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)
        im2_2d = images2[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)
        joint_similarity = sims[n]
        weight = weights[n]
        if weight[0] < weight[1]:
            A1[im1_2d == 0] += joint_similarity
        else:
            A2[im2_2d == 0] += joint_similarity

    A1 = A1 / (np.max(A1) + 1e-9)
    A2 = A2 / (np.max(A2) + 1e-9)
    return A1, A2
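
# Total-variation regularizer (L1 or squared) over neighboring pixel
# differences; used by synthesize() below to keep optimized images smooth.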
def tv_reg(img, l1=True):
    diff_i = img[:, :, :, 1:] - img[:, :, :, :-1]
    diff_j = img[:, :, 1:, :] - img[:, :, :-1, :]
    if l1:
        return diff_i.abs().sum() + diff_j.abs().sum()
    else:
        return diff_i.pow(2).sum() + diff_j.pow(2).sum()
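
# Feature inversion: optimize a random image so that its activations at an
# intermediate layer match those of img1, regularized by an alpha-norm term and
# total variation. `model_type` selects between a supervised ImageNet backbone
# and the SSL encoder.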
def synthesize(ssl_model, model_type, img1, img_cls_layer, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):
    if model_type == 'imagenet':
        reduce_lr = False
        model = torchvision.models.resnet50(pretrained=True)
        model = list(model.children())[:img_cls_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()
    else:
        reduce_lr = True
        shift_layer = 3 if network == 'simclrv2' else 0
        equivalent_layer = img_cls_layer - shift_layer
        model = list(ssl_model.encoder.net.children())[:equivalent_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()

    opt_img = (init_scale * torch.randn(1, 3, 224, 224)).to(device).requires_grad_()
    target_feats = model(img1).detach()
    optimizer = torch.optim.SGD([opt_img], lr=lr, momentum=0.9)

    for i in range(201):
        opt_img.data = opt_img.data.clip(0, 1)
        optimizer.zero_grad()
        output = model(opt_img)
        l2_loss = l2_weight * ((output - target_feats) ** 2).sum() / (target_feats ** 2).sum()
        reg_alpha = alpha_weight * (opt_img ** alpha_power).sum()
        reg_total_variation = tv_weight * tv_reg(opt_img, l1=False)
        loss = l2_loss + reg_alpha + reg_total_variation
        loss.backward()
        optimizer.step()
        if reduce_lr and i % 40 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1

    return opt_img
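
# Invert layers 4-6 for both the supervised baseline and the SSL model so their
# learned representations can be compared side by side.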
def get_difference(ssl_model, baseline, image, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):
    imagenet_images = []
    ssl_images = []
    for lay in range(4, 7):
        image_net_image = synthesize(ssl_model, baseline, image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        ssl_image = synthesize(ssl_model, 'ssl', image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        imagenet_images.append(image_net_image)
        ssl_images.append(ssl_image)
    return imagenet_images, ssl_images
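
# Build the interpolation path for integrated-gradients-style attribution: a
# sequence of blends between the 'pure' view and a transformed view, optionally
# with Gaussian noise added to the intermediate images.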
def create_mixed_images(transform_type, ig_transforms, step, img_path, add_noise):
    img = Image.open(img_path).convert('RGB')
    img1 = ig_transforms['pure'](img).unsqueeze(0).to(device)
    img2 = ig_transforms[transform_type](img).unsqueeze(0).to(device)

    lambdas = np.arange(1, 0, -step)
    mixed_images = []
    for lam in lambdas:
        mixed_img = lam * img1 + (1 - lam) * img2
        mixed_images.append(mixed_img)

    if add_noise:
        sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()
        mixed_images = [im + torch.zeros_like(im).normal_(0, sigma) if (n > 0) and (n < len(mixed_images) - 1) else im
                        for n, im in enumerate(mixed_images)]

    return mixed_images
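
# Integrated-gradients-style saliency: average the gradients of the similarity
# score over the mixed images produced by create_mixed_images. With guided=True,
# the ReLU backward hooks implement guided backpropagation.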
def averaged_transforms(guided, ssl_model, mixed_images, blur_output):
    measure = nn.CosineSimilarity(dim=-1)
    if guided:
        handles = []
        for module in ssl_model.modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    grads1 = []
    grads2 = []
    for xbar_image in mixed_images[1:]:
        input_image1 = mixed_images[0].clone().requires_grad_()
        input_image2 = xbar_image.clone().requires_grad_()
        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()
        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)
    sailency1, _ = torch.max((mixed_images[0] * grads1).abs(), dim=1)
    sailency2, _ = torch.max((mixed_images[-1] * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2
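
# Plain (or guided) gradient saliency of the cosine similarity between the two
# embeddings, reduced to one channel via a max over input-times-gradient.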
def sailency(guided, ssl_model, img1, img2, blur_output):
    measure = nn.CosineSimilarity(dim=-1)
    if guided:
        handles = []
        for module in ssl_model.modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    input_image1 = img1.clone().requires_grad_()
    input_image2 = img2.clone().requires_grad_()
    score = measure(ssl_model(input_image1), ssl_model(input_image2))
    score.backward()
    grads1 = input_image1.grad.data
    grads2 = input_image2.grad.data

    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2
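
# SmoothGrad: average input gradients over several noisy copies of both images
# to reduce gradient noise in the saliency maps.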
def smooth_grad(guided, ssl_model, img1, img2, blur_output, steps=50):
    measure = nn.CosineSimilarity(dim=-1)
    sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()
    if guided:
        handles = []
        for module in ssl_model.modules():
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    noise_images1 = []
    noise_images2 = []
    for _ in range(steps):
        noise = torch.zeros_like(img1).normal_(0, sigma)
        noise_images1.append(img1 + noise)
        noise_images2.append(img2 + noise)

    grads1 = []
    grads2 = []
    for n1, n2 in zip(noise_images1, noise_images2):
        input_image1 = n1.clone().requires_grad_()
        input_image2 = n2.clone().requires_grad_()
        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()
        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)
    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2
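
# Build a small dataset for the pixel-invariance surrogate: pairs of (pure,
# augmented) views of one image, the SSL similarity between their embeddings as
# labels, and NMF-derived scores of the augmented features as invariance labels.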
def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):
    measure = nn.CosineSimilarity(dim=-1)
    img = Image.open(img_path).convert('RGB')
    no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
                                       transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])
    augments2 = [no_shift_aug(img).unsqueeze(0) for _ in range(num_augments)]
    data_samples1 = no_shift_transforms['pure'](img).unsqueeze(0).expand(num_augments, -1, -1, -1).to(device)
    data_samples2 = torch.cat(augments2).to(device)

    labels = []
    feats_invariance = []
    for b in range(0, data_samples1.shape[0], batch_size):
        with torch.no_grad():
            out1 = ssl_model(data_samples1[b : b + batch_size, :])
            out2 = ssl_model(data_samples2[b : b + batch_size, :])
            labels.append(measure(out1, out2))
            feats_invariance.append(F.relu(out2))

    data_labels = torch.cat(labels).unsqueeze(-1).to(device)
    feats_invariance = torch.cat(feats_invariance).to(device)

    # Factorize feats_invariance (T, 2048) as W.H, with W of shape (T, N) and H of
    # shape (N, 2048); sklearn's fit_transform returns W, whose rows score each transform.
    nmf_model = NMF(n_components=n_components, init='random')
    W = nmf_model.fit_transform(feats_invariance.cpu().numpy())
    labels_invariance = torch.from_numpy(W.mean(1)).unsqueeze(-1).to(device)

    return data_samples1, data_samples2, data_labels, labels_invariance
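
# Train a linear surrogate on the (pure, augmented) pairs and read its weights
# back as a pixel-level invariance heatmap.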
def pixel_invariance(data_samples1, data_samples2, data_labels, labels_invariance, resize_transform, size, epochs,
                     learning_rate, l1_weight, zero_small_values, blur_output, nmf_weight):
    """
    size: the image is resized to (size, size) while training the surrogate; it is upsized later
    epochs: number of epochs to train the surrogate model
    learning_rate: learning rate used to train the surrogate model
    l1_weight: if not None, enables l1 regularization (sparsity) with this weight
    zero_small_values: if True, zeroes out normalized heatmap entries below 0.5
    blur_output: if True, blurs the final heatmap
    nmf_weight: weight of the NMF-based invariance regression term
    """
    x1 = resize_transform((size, size))(data_samples1)  # (num_samples, 3, size, size)
    x2 = resize_transform((size, size))(data_samples2)  # (num_samples, 3, size, size)
    x1 = x1.reshape(x1.size(0), -1).to(device)
    x2 = x2.reshape(x2.size(0), -1).to(device)

    surrogate = nn.Linear(size * size * 3, 1).to(device)
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
    invariance_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(surrogate.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        pred1, pred2 = surrogate(x1), surrogate(x2)
        preds = (pred1 + pred2) / 2
        loss = criterion(preds, data_labels)
        loss += nmf_weight * invariance_criterion(torch.sigmoid(preds), labels_invariance)
        if l1_weight is not None:
            loss += l1_weight * sum(p.abs().sum() for p in surrogate.parameters())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    heatmap = surrogate.weight.detach().reshape(3, size, size)  # detach so post-processing does not track gradients
    heatmap, _ = torch.max(heatmap, 0)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    if zero_small_values:
        heatmap[heatmap < 0.5] = 0
    if blur_output:
        heatmap = blur_sailency(heatmap.unsqueeze(0)).squeeze(0)

    return heatmap
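
# Grad-CAM adapted to a two-branch SSL model: the forward pass caches the spatial
# features of both images and hooks their gradients, then returns the cosine
# similarity between the two projected embeddings as the score to differentiate.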
class GradCAM(nn.Module):

    def __init__(self, ssl_model):
        super(GradCAM, self).__init__()
        self.gradients = {}
        self.features = {}
        self.feature_extractor = ssl_model.encoder.net
        self.contrastive_head = ssl_model.contrastive_head
        self.measure = nn.CosineSimilarity(dim=-1)

    def save_grads(self, img_index):
        def hook(grad):
            self.gradients[img_index] = grad.detach()
        return hook

    def save_features(self, img_index, feats):
        self.features[img_index] = feats.detach()

    def forward(self, img1, img2):
        features1 = self.feature_extractor(img1)
        features2 = self.feature_extractor(img2)
        self.save_features('1', features1)
        self.save_features('2', features2)
        h1 = features1.register_hook(self.save_grads('1'))
        h2 = features2.register_hook(self.save_grads('2'))
        out1, out2 = features1.mean(dim=[2, 3]), features2.mean(dim=[2, 3])
        out1, out2 = self.contrastive_head(out1), self.contrastive_head(out2)
        score = self.measure(out1, out2)
        return score
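
# Combine features and rectified gradients elementwise and sum over channels to
# obtain a spatial class-activation map.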
def weight_activation(feats, grads):
    cam = feats * F.relu(grads)
    cam = torch.sum(cam, dim=1).squeeze().cpu().detach().numpy()
    return cam
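
# Grad-CAM maps for an image pair: backpropagate the pair's similarity score and
# weight each image's features by its own gradients.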
def get_gradcam(ssl_model, img1, img2):
    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()
    cam1 = weight_activation(grad_cam.features['1'], grad_cam.gradients['1'])
    cam2 = weight_activation(grad_cam.features['2'], grad_cam.gradients['2'])
    return cam1, cam2
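
# Interaction-CAM: like get_gradcam, but the channel weights combine *both*
# images' features (mean, max, or attention-based reduction), and optionally a
# gradient interaction matrix replaces the raw gradients.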
def get_interactioncam(ssl_model, img1, img2, reduction, grad_interact=False):
    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()

    G1 = grad_cam.gradients['1']
    G2 = grad_cam.gradients['2']
    if grad_interact:
        B, D, H, W = G1.size()
        G1_ = G1.permute(0, 2, 3, 1).view(B, H * W, D)
        G2_ = G2.permute(0, 2, 3, 1).view(B, H * W, D)
        G_ = torch.bmm(G1_.permute(0, 2, 1), G2_)  # (B, D, D)
        G1, _ = torch.max(G_, dim=-1)  # (B, D)
        G2, _ = torch.max(G_, dim=1)   # (B, D)
        G1 = G1.unsqueeze(-1).unsqueeze(-1)
        G2 = G2.unsqueeze(-1).unsqueeze(-1)

    if reduction == 'mean':
        joint_weight = grad_cam.features['1'].mean([2, 3]) * grad_cam.features['2'].mean([2, 3])
    elif reduction == 'max':
        max_pooled1 = F.max_pool2d(grad_cam.features['1'], kernel_size=grad_cam.features['1'].size()[2:]).squeeze(-1).squeeze(-1)
        max_pooled2 = F.max_pool2d(grad_cam.features['2'], kernel_size=grad_cam.features['2'].size()[2:]).squeeze(-1).squeeze(-1)
        joint_weight = max_pooled1 * max_pooled2
    else:
        B, D, H, W = grad_cam.features['1'].size()
        reshaped1 = grad_cam.features['1'].permute(0, 2, 3, 1).reshape(B, H * W, D)
        reshaped2 = grad_cam.features['2'].permute(0, 2, 3, 1).reshape(B, H * W, D)
        features1_query, features2_query = reshaped1.mean(1).unsqueeze(1), reshaped2.mean(1).unsqueeze(1)
        attn1 = (features1_query @ reshaped1.transpose(-2, -1)).softmax(dim=-1)
        attn2 = (features2_query @ reshaped2.transpose(-2, -1)).softmax(dim=-1)
        att_reduced1 = (attn1 @ reshaped1).squeeze(1)
        att_reduced2 = (attn2 @ reshaped2).squeeze(1)
        joint_weight = att_reduced1 * att_reduced2

    joint_weight = joint_weight.unsqueeze(-1).unsqueeze(-1).expand_as(grad_cam.features['1'])
    feats1 = grad_cam.features['1'] * joint_weight
    feats2 = grad_cam.features['2'] * joint_weight
    cam1 = weight_activation(feats1, G1)
    cam2 = weight_activation(feats2, G2)

    return cam1, cam2
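
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline).
# Assumptions: two local files "img1.jpg" / "img2.jpg" (hypothetical paths), and
# a torchvision ResNet-50 with its classifier removed standing in for an SSL
# encoder that maps a (B, 3, 224, 224) batch to a (B, D) embedding. GradCAM and
# get_interactioncam additionally require a model exposing `encoder.net` and
# `contrastive_head`, so they are not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    img1 = preprocess(Image.open("img1.jpg").convert('RGB')).unsqueeze(0).to(device)
    img2 = preprocess(Image.open("img2.jpg").convert('RGB')).unsqueeze(0).to(device)

    backbone = torchvision.models.resnet50(pretrained=True)
    backbone.fc = nn.Identity()  # expose the 2048-d pooled features as the embedding
    backbone = backbone.to(device).eval()

    # Perturbation-based heatmaps (output_size x output_size grids).
    heatmap1, heatmap2 = occlusion(img1, img2, backbone, w_size=64, stride=8)

    # Gradient saliency of the pair's cosine similarity.
    sal1, sal2 = sailency(guided=False, ssl_model=backbone, img1=img1, img2=img2, blur_output=True)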