# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the module.
import torch | |
from torch import nn | |
import math | |
from PIL import Image, ImageDraw, ImageFont | |
import logging | |
import os | |
import pandas as pd | |
import csv | |
import pickle | |
import numpy as np | |
from torch.nn import BCELoss | |
from torch.nn import functional as F | |
import math | |
import numbers | |
from typing import List | |
def _resize_attention_to_64(attn_entry, target_res):
    """Upsample one stored attention entry to a 64x64 grid.

    Args:
        attn_entry: Either an empty list (skipped) or a tensor that is 3-D
            after ``squeeze(0)``; its middle dimension is treated as a
            flattened square grid.
        target_res: Side length the entry's grid must have to be used.

    Returns:
        A tensor of shape ``(n, 64, 64, last_dim)``, or ``None`` when the
        entry is an empty list or its grid side differs from ``target_res``.
    """
    # Explicit type check instead of comparing a tensor to ``[]`` with ``==``.
    if isinstance(attn_entry, list) and not attn_entry:
        return None

    attn_map = attn_entry.squeeze(0)
    # Unpacking three values also rejects inputs that are not 3-D here.
    _, num_cells, last_dim = attn_map.shape
    if int(math.sqrt(num_cells)) != target_res:
        return None

    grid = attn_map.reshape(-1, target_res, target_res, last_dim)
    # F.interpolate resizes the trailing two dims, so move the grid there
    # and restore the original layout afterwards.
    grid = F.interpolate(grid.permute(0, 3, 1, 2), 64, mode='bilinear')
    return grid.permute(0, 2, 3, 1)


def get_all_attention_64(attn_maps_down, attn_maps_mid, attn_maps_up, res=16):
    """Average all matching attention maps after resizing them to 64x64.

    Entries of ``attn_maps_up`` and ``attn_maps_down`` are used when their
    grid side equals ``res``; entries of ``attn_maps_mid`` are used when
    their grid side equals 8. Empty-list entries are skipped in all three
    lists.

    Args:
        attn_maps_down: Iterable of attention tensors (or empty lists).
        attn_maps_mid: Iterable of attention tensors (or empty lists).
        attn_maps_up: Iterable of attention tensors (or empty lists).
        res: Grid side length selected from the up and down lists.

    Returns:
        Tensor of shape ``(64, 64, last_dim)``: the mean over dim 0 of all
        selected, upsampled maps. The last dimension is presumably the
        text-token axis -- TODO confirm against the caller.

    Note:
        ``torch.cat`` is called on the collected list, so at least one entry
        must be selected.
    """
    # Same visiting order as before: up, then mid, then down.
    sources = (
        (attn_maps_up, res),
        (attn_maps_mid, 8),
        (attn_maps_down, res),
    )
    resized = []
    for attn_maps, target_res in sources:
        for entry in attn_maps:
            item = _resize_attention_to_64(entry, target_res)
            if item is not None:
                resized.append(item)

    stacked = torch.cat(resized, dim=0)
    return stacked.sum(0) / stacked.shape[0]
def compute_loco_v2(attn_maps_down, attn_maps_mid, attn_maps_up, bboxes, object_positions, smooth_attn=True, topk=0.8):
    """Compute the box-constrained attention loss plus a padding-token term.

    The attention maps are merged into one ``(H, W, last_dim)`` map by
    ``get_all_attention_64``. For every object, the slices selected by its
    positions are summed, divided by their maximum, and split into the mass
    inside and outside the object's boxes. The first loss term is
    ``(1 - inside / (inside + outside)) ** 2`` over the totals of all
    objects. The second term is a binary cross-entropy between the sigmoid
    of a blend of the first and last slices of the map and that blend masked
    by the union of all boxes; it is weighted by ``alpha = 0.2``.

    Args:
        attn_maps_down: Indexable of lists; the last two are concatenated.
        attn_maps_mid: Passed through to ``get_all_attention_64``.
        attn_maps_up: Indexable of lists; the first two are concatenated.
        bboxes: One list of ``(x_min, y_min, x_max, y_max)`` boxes per
            object; coordinates are multiplied by the map size, so they are
            presumably normalized to [0, 1] -- TODO confirm.
        object_positions: One list of indices into the map's last dimension
            per object.
        smooth_attn: Unused; kept for interface compatibility.
        topk: Unused; kept for interface compatibility.

    Returns:
        Scalar tensor. A zero tensor is returned when ``bboxes`` is empty.

    Raises:
        ZeroDivisionError: If an object has no positions.
    """
    # NOTE(review): the nesting was reconstructed from a flattened source.
    # The padding term and the final ratio are computed once, after all
    # objects have been visited -- confirm against upstream.
    alpha = 0.2  # weight of the padding-token term
    beta = 0.8   # weight of the first-slice map inside the blended map

    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()

    attn_map = get_all_attention_64(
        attn_maps_down[-1] + attn_maps_down[-2],
        attn_maps_mid,
        attn_maps_up[0] + attn_maps_up[1],
        16,
    )
    side, _, _ = attn_map.shape
    H = W = side

    # Allocate on the attention map's device rather than calling .cuda()
    # unconditionally, which failed on CPU-only hosts.
    device = attn_map.device
    total_fg_map = torch.zeros(size=(H, W), device=device)

    sum_in = 0.
    sum_out = 0.
    for obj_idx in range(object_number):
        token_positions = object_positions[obj_idx]
        if len(token_positions) == 0:
            # The previous code divided by this length; the exception type
            # is kept so existing callers see the same failure.
            raise ZeroDivisionError(
                "object %d has no token positions" % obj_idx
            )

        mask = torch.zeros(size=(H, W), device=device)
        for obj_box in bboxes[obj_idx]:
            x_min, y_min = int(obj_box[0] * W), int(obj_box[1] * H)
            x_max, y_max = int(obj_box[2] * W), int(obj_box[3] * H)
            mask[y_min: y_max, x_min: x_max] = 1
            total_fg_map[y_min: y_max, x_min: x_max] = 1

        # Sum the slices of all positions of this object, then scale the
        # result so that its maximum is 1.
        ca_map_obj = attn_map[:, :, token_positions].sum(-1).reshape(H, W)
        norm_ca_map_obj = ca_map_obj / ca_map_obj.max()
        sum_in += (norm_ca_map_obj * mask).sum()
        sum_out += (norm_ca_map_obj * (1 - mask)).sum()

    # First and last slices of the last dimension; presumably the
    # start-of-text and end-of-text tokens -- TODO confirm.
    sot_map = attn_map[:, :, 0].reshape(H, W)
    eot_map = attn_map[:, :, -1].reshape(H, W)
    norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
    norm_eot_map = eot_map / eot_map.max()
    pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
    fg_map = pad_map * total_fg_map

    # NOTE(review): float16 BCE may be unsupported on CPU in some PyTorch
    # builds -- confirm before relying on CPU execution.
    pad_loss = F.binary_cross_entropy(
        torch.sigmoid(pad_map.to(torch.float16).reshape(-1)),
        fg_map.to(torch.float16).reshape(-1),
    )
    loss = (1 - sum_in / (sum_in + sum_out)) ** 2
    return loss + alpha * pad_loss
def _box_mask(boxes, height, width, device):
    """Return a ``(height, width)`` mask that is 1 inside any of ``boxes``."""
    mask = torch.zeros(size=(height, width), device=device)
    for box in boxes:
        x_min, y_min = int(box[0] * width), int(box[1] * height)
        x_max, y_max = int(box[2] * width), int(box[3] * height)
        mask[y_min: y_max, x_min: x_max] = 1
    return mask


def _object_ca_loss(attn_map, mask, positions):
    """Mean over ``positions`` of ``(1 - attention share inside mask) ** 2``."""
    b = attn_map.shape[0]
    height, width = mask.shape
    obj_loss = 0
    for obj_position in positions:
        ca_map_obj = attn_map[:, :, obj_position].reshape(b, height, width)
        inside = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
        total = ca_map_obj.reshape(b, -1).sum(dim=-1)
        obj_loss += torch.mean((1 - inside / total) ** 2)
    # An empty ``positions`` list raises ZeroDivisionError here, as before.
    return obj_loss / len(positions)


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Compute the box-constrained cross-attention loss.

    For every map in ``attn_maps_mid`` and ``attn_maps_up[0]``, the second
    chunk along dim 0 is used (presumably the conditional half of a
    guidance batch -- TODO confirm). For every object and each of its
    positions, the loss is ``(1 - share of attention inside the object's
    boxes) ** 2``, averaged over the chunk's first dimension and over the
    object's positions. The sum is divided by the number of objects times
    the number of maps.

    Args:
        attn_maps_mid: List of tensors of shape ``(2 * b, cells, last_dim)``
            where ``cells`` is a flattened square grid.
        attn_maps_up: Indexable whose first element is a list of such
            tensors.
        bboxes: One list of ``(x_min, y_min, x_max, y_max)`` boxes per
            object; coordinates are multiplied by the grid size, so they
            are presumably normalized to [0, 1] -- TODO confirm.
        object_positions: One list of indices into the last dimension per
            object.

    Returns:
        Scalar tensor. A zero tensor is returned when ``bboxes`` is empty.

    Raises:
        ZeroDivisionError: If an object has no positions, or if there are
            no attention maps at all.
    """
    object_number = len(bboxes)
    if object_number == 0:
        # Same device selection as compute_loco_v2 instead of an
        # unconditional .cuda(), which failed on CPU-only hosts.
        return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()

    loss = 0
    # The mid maps are processed first, then the first group of up maps.
    for layer_maps in (attn_maps_mid, attn_maps_up[0]):
        for attn_map_integrated in layer_maps:
            attn_map = attn_map_integrated.chunk(2)[1]
            _, num_cells, _ = attn_map.shape
            H = W = int(math.sqrt(num_cells))
            for obj_idx in range(object_number):
                # Masks live on the attention map's device, so CPU and GPU
                # inputs both work.
                mask = _box_mask(bboxes[obj_idx], H, W, attn_map.device)
                loss += _object_ca_loss(attn_map, mask, object_positions[obj_idx])

    loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
    return loss