Spaces:
Sleeping
Sleeping
import torch | |
import math | |
def _infer_device(attn_maps_mid, attn_maps_up):
    """Pick the device for the zero loss returned when there are no objects.

    Uses the device of the first available attention map so the result lives
    next to the tensors the caller is working with. When no map is available,
    falls back to CUDA if present (the historical behavior) and CPU otherwise.
    """
    candidates = list(attn_maps_mid or [])
    if attn_maps_up:
        candidates.extend(attn_maps_up[0] or [])
    if candidates:
        return candidates[0].device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _box_mask(boxes, height, width, device):
    """Return a (height, width) float mask set to 1 inside every given box.

    Each box is indexed as (x_min, y_min, x_max, y_max); the coordinates are
    scaled by the map width/height and truncated to integer pixel indices.
    """
    mask = torch.zeros(size=(height, width), device=device)
    for box in boxes:
        x_min, y_min = int(box[0] * width), int(box[1] * height)
        x_max, y_max = int(box[2] * width), int(box[3] * height)
        mask[y_min:y_max, x_min:x_max] = 1
    return mask


def _object_ca_loss(attn_map, mask, token_positions):
    """Return the mean over token positions of (1 - in-box attention share)^2.

    Raises ZeroDivisionError when ``token_positions`` is empty.
    """
    batch = attn_map.shape[0]
    height, width = mask.shape
    obj_loss = 0
    for position in token_positions:
        ca_map_obj = attn_map[:, :, position].reshape(batch, height, width)
        inside = (ca_map_obj * mask).reshape(batch, -1).sum(dim=-1)
        total = ca_map_obj.reshape(batch, -1).sum(dim=-1)
        # Share of this token's attention that falls inside the object's
        # boxes; becomes NaN if the token's attention sums to zero.
        activation_value = inside / total
        obj_loss += torch.mean((1 - activation_value) ** 2)
    return obj_loss / len(token_positions)


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Compute a loss that rewards attention concentrated inside object boxes.

    For every attention map, every object and every token position of that
    object, the fraction of the token's attention lying inside the object's
    boxes is measured and ``(1 - fraction) ** 2`` is averaged over the batch.
    The per-object terms are averaged over the object's token positions,
    summed over objects and maps, and finally divided by
    ``len(bboxes) * (len(attn_maps_up[0]) + len(attn_maps_mid))``.

    All intermediate tensors are created on the device of the attention map
    being processed, so the function works on CPU as well as on any GPU.

    Args:
        attn_maps_mid: Iterable of 3-D attention tensors. Each is split in two
            along dim 0 and only the second half is used (presumably the
            conditional half of a classifier-free-guidance batch - TODO
            confirm with the caller). Dim 1 must be a perfect square
            (H * W with H == W); dim 2 is indexed by token position.
        attn_maps_up: Sequence whose first element is an iterable of attention
            tensors with the same layout; only ``attn_maps_up[0]`` is read.
        bboxes: One entry per object; each entry is an iterable of boxes
            indexed as (x_min, y_min, x_max, y_max), scaled by the map size.
        object_positions: One entry per object; each entry is a sequence of
            indices into the last dimension of the attention maps.

    Returns:
        The normalized loss as a scalar tensor. When ``bboxes`` is empty a
        float32 zero tensor is returned without inspecting the maps' contents.

    Raises:
        ZeroDivisionError: If an object has no token positions, or if there
            are no attention maps at all while ``bboxes`` is non-empty.
    """
    object_number = len(bboxes)
    if object_number == 0:
        device = _infer_device(attn_maps_mid, attn_maps_up)
        return torch.tensor(0).float().to(device)

    up_maps = attn_maps_up[0]
    loss = 0
    # Mid-block maps first, then the first up-block group; a single running
    # accumulator keeps the summation order of the original implementation.
    for attn_map_integrated in [*attn_maps_mid, *up_maps]:
        attn_map = attn_map_integrated.chunk(2)[1]
        # The flattened spatial axis is assumed square: H == W == sqrt(H * W).
        side = int(math.sqrt(attn_map.shape[1]))
        for obj_idx in range(object_number):
            mask = _box_mask(bboxes[obj_idx], side, side, attn_map.device)
            loss += _object_ca_loss(attn_map, mask, object_positions[obj_idx])

    return loss / (object_number * (len(up_maps) + len(attn_maps_mid)))