"""Attention-map losses for box-conditioned (layout-guided) diffusion sampling.

Each loss takes attention maps (presumably collected from the UNet's
down/mid/up blocks -- TODO confirm against the caller) plus, per object, a
list of boxes in normalised ``(x_min, y_min, x_max, y_max)`` coordinates
(they are scaled by the map width/height below) and returns a scalar loss.
"""
import math

import torch
from torch.nn import functional as F
from torchvision.utils import save_image  # noqa: F401 -- currently unused in this module

from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing


# ---------------------------------------------------------------------------
# Small private helpers
# ---------------------------------------------------------------------------

def _zero_loss(device=None):
    """Return a scalar float32 zero used as the loss when there is nothing to constrain.

    With ``device=None`` the tensor is placed on CUDA when available, else CPU
    (the same placement the original inline expressions used).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.zeros((), dtype=torch.float32, device=device)


def _box_to_pixels(box, H, W):
    """Scale a normalised ``(x_min, y_min, x_max, y_max)`` box to integer pixel bounds."""
    return int(box[0] * W), int(box[1] * H), int(box[2] * W), int(box[3] * H)


def _boxes_mask(boxes, H, W, device):
    """Return an ``(H, W)`` float32 mask that is 1 inside any of ``boxes`` and 0 elsewhere."""
    mask = torch.zeros((H, W), device=device)
    for box in boxes:
        x_min, y_min, x_max, y_max = _box_to_pixels(box, H, W)
        mask[y_min:y_max, x_min:x_max] = 1
    return mask


def _smooth_map(att_map, kernel_size, sigma):
    """Gaussian-smooth a 2-D ``(H, W)`` attention map with reflect padding.

    The pad is ``kernel_size // 2`` per side (1 for the default kernel of 3,
    exactly as before). Assuming ``GaussianSmoothing`` applies an unpadded
    convolution -- which the original 1-pixel pad for a 3-tap kernel implies --
    odd integer kernels keep the map ``H x W`` so box coordinates stay valid.
    Non-integer kernel specs keep the original 1-pixel pad.
    """
    # .to(device) instead of .cuda(): follow the map's device (assumes
    # GaussianSmoothing is an nn.Module, as its use of .cuda() implied).
    smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(att_map.device)
    pad = kernel_size // 2 if isinstance(kernel_size, int) else 1
    padded = F.pad(att_map.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    return smoothing(padded).squeeze(0).squeeze(0)


def _pad_map(attn_map, beta):
    """Blend the inverted start-token map and the end-token map, each max-normalised.

    ``attn_map`` is indexed as ``(H, H, num_tokens)``; token 0 is treated as the
    start token and token -1 as the end token. Returns an ``(H, H)`` map equal to
    ``beta * (1 - sot) / max(1 - sot) + (1 - beta) * eot / max(eot)``.
    """
    H = W = attn_map.shape[0]
    sot_map = attn_map[:, :, 0].reshape(H, W)
    eot_map = attn_map[:, :, -1].reshape(H, W)
    norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
    norm_eot_map = eot_map / eot_map.max()
    return beta * norm_sot_map + (1 - beta) * norm_eot_map


def _iter_cross_attention(attn_maps_mid, attn_maps_up, attn_maps_down):
    """Yield ``(attn_map, is_mid)`` for every stored cross-attention map.

    Order is up, then mid, then down (the original loop order). Up/down entries
    are unwrapped as ``entry[0][0]`` and mid entries as ``entry[0]``; entries
    that are the empty list ``[]`` are skipped for up/down, as before.
    """
    for entry in attn_maps_up:
        if entry == []:
            continue
        yield entry[0][0], False
    for entry in attn_maps_mid:
        yield entry[0], True
    for entry in attn_maps_down:
        if entry == []:
            continue
        attn_map = entry[0][0]
        if attn_map == []:
            continue
        yield attn_map, False


# ---------------------------------------------------------------------------
# Self-attention loss
# ---------------------------------------------------------------------------

def loss_one_att_outside(attn_map, bboxes, object_positions, t):
    """Penalise self-attention that in-box query pixels pay to pixels outside the box.

    Args:
        attn_map: tensor unpacked as ``(b, i, j)`` where ``i`` must be a square
            ``H * H``. Each query row is reshaped to ``(H, H)``, so this assumes
            ``j == i`` -- TODO confirm.
        bboxes: per-object lists of normalised boxes.
        object_positions: unused; kept for interface compatibility.
        t: unused; kept for interface compatibility.

    Returns:
        For every box, the in-box queries' average attention map is summed
        outside the box and averaged over the batch. These values are summed
        over all boxes and divided by the number of objects. Returns a zero
        scalar when there are no objects (the original divided by zero).
        Degenerate (empty) boxes contribute nothing instead of NaN.
    """
    object_number = len(bboxes)
    if object_number == 0:
        return torch.zeros((), dtype=attn_map.dtype, device=attn_map.device)

    b, i, j = attn_map.shape
    H = W = int(math.sqrt(i))
    loss = 0
    for obj_idx in range(object_number):
        for obj_box in bboxes[obj_idx]:
            mask = _boxes_mask([obj_box], H, W, attn_map.device)
            # Row-major flat indices of the query pixels inside the box.
            index_in_key = torch.nonzero(mask.reshape(-1), as_tuple=True)[0]
            if index_in_key.numel() == 0:
                # Empty box: the average below would be 0 / 0 = NaN.
                continue
            # Mean attention distribution of the in-box queries, as a (b, H, H) map.
            att_box = attn_map[:, index_in_key, :].sum(dim=1) / index_in_key.shape[0]
            att_box = att_box.reshape(-1, H, H)
            activation_value = (att_box * (1. - mask)).reshape(b, -1).sum(dim=-1)
            loss += torch.mean(activation_value)
    return loss / object_number


def caculate_loss_self_att(self_first, self_second, self_third, bboxes, object_positions, t,
                           list_res=(256,), smooth_att=True, sigma=0.5, kernel_size=3):
    """Average ``loss_one_att_outside`` over every self-attention map at the given sizes.

    ``list_res`` entries are query counts, i.e. keys of ``get_all_self_att``
    (e.g. 256 for a 16x16 map). The default is now an immutable tuple instead
    of a shared mutable list. ``smooth_att``, ``sigma`` and ``kernel_size`` are
    accepted for signature parity with the other losses but are unused.
    """
    all_attn = get_all_self_att(self_first, self_second, self_third)
    total_loss = 0
    cnt = 0
    for res in list_res:
        for attn in all_attn[res]:
            total_loss += loss_one_att_outside(attn, bboxes, object_positions, t)
            cnt += 1
    return total_loss / cnt


def get_all_self_att(self_first, self_second, self_third):
    """Group self-attention maps by their query length ``shape[1]``.

    For each of the three inputs, every non-empty entry's first element is
    iterated and each map in it is appended under ``attn_map.shape[1]``. The
    listed keys are always present (possibly empty). Other sizes are now added
    on demand instead of raising ``KeyError``.
    """
    result = {256: [], 1024: [], 4096: [], 64: [], 94: [], 1054: [], 286: [], 4126: []}
    for self_att in (self_first, self_second, self_third):
        for att in self_att:
            if att == []:
                continue
            for attn_map in att[0]:
                result.setdefault(attn_map.shape[1], []).append(attn_map)
    return result


# ---------------------------------------------------------------------------
# Cross-attention aggregation
# ---------------------------------------------------------------------------

def get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res):
    """Average every cross-attention map whose spatial side is ``res``.

    Each map is unpacked as ``(b, i, j)``. Maps with ``int(sqrt(i)) == res`` are
    reshaped to ``(b, res, res, j)``, concatenated on the first axis and
    averaged, which gives a ``(res, res, j)`` tensor.

    Raises:
        RuntimeError: if no stored map has side ``res``.
    """
    maps = []
    for attn_map, _ in _iter_cross_attention(attn_maps_mid, attn_maps_up, attn_maps_down):
        b, i, j = attn_map.shape
        if int(math.sqrt(i)) == res:
            maps.append(attn_map.reshape(-1, res, res, j))
    if not maps:
        raise RuntimeError(f'get_all_attention: no attention maps with resolution {res}')
    result = torch.cat(maps, dim=0)
    return result.sum(0) / result.shape[0]


def get_all_attention_64(attn_maps_mid, attn_maps_up, attn_maps_down, res):
    """Average cross-attention maps after bilinearly resizing each to 64x64.

    Up/down maps are used when their side is ``res``. Mid maps are used when
    their side is 8, regardless of ``res``. Returns a ``(64, 64, j)`` tensor.

    Raises:
        RuntimeError: if no stored map qualifies.
    """
    maps = []
    for attn_map, is_mid in _iter_cross_attention(attn_maps_mid, attn_maps_up, attn_maps_down):
        b, i, j = attn_map.shape
        side = 8 if is_mid else res
        if int(math.sqrt(i)) != side:
            continue
        # (b, side, side, j) -> (b, j, side, side) so tokens act as channels for interpolate.
        item = attn_map.reshape(-1, side, side, j).permute(0, 3, 1, 2)
        item = F.interpolate(item, 64, mode='bilinear').permute(0, 2, 3, 1)
        maps.append(item)
    if not maps:
        raise RuntimeError(f'get_all_attention_64: no attention maps with resolution {res}')
    result = torch.cat(maps, dim=0)
    return result.sum(0) / result.shape[0]


# ---------------------------------------------------------------------------
# Cross-attention losses
# ---------------------------------------------------------------------------

def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t,
                                res=16, smooth_att=True, sigma=0.5, kernel_size=3):
    """Max-inside / max-outside box loss on the averaged ``res x res`` cross-attention.

    The first and last tokens are dropped, the remaining scores are multiplied
    by 100 and softmax-ed over tokens. For each token of each object (optionally
    Gaussian-smoothed) the loss adds:
      * ``1 - max`` attention inside each of the object's boxes (0 for an empty box);
      * the max attention inside each box of every *other* object, after the
        object's own boxes have been zeroed;
      * ``len(boxes) / 2`` times the max attention left outside all boxes.
    The total is divided by the number of objects. With no objects a zero
    scalar is returned (the original divided by zero).
    """
    obj_number = len(bboxes)
    if obj_number == 0:
        return _zero_loss()

    attn = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
    H = W = attn.shape[0]
    # Out-of-place scaling, so `attn` is no longer mutated through a view.
    attn_text = F.softmax(attn[:, :, 1:-1] * 100, dim=-1)

    total_loss = 0
    for obj_idx in range(obj_number):
        for obj_position in object_positions[obj_idx]:
            # Positions count the first token, which was sliced off above.
            att_map_obj = attn_text[:, :, obj_position - 1]
            if smooth_att:
                att_map_obj = _smooth_map(att_map_obj, kernel_size, sigma)
            other_att_map_obj = att_map_obj.clone()  # own boxes zeroed
            att_copy = att_map_obj.clone()            # all boxes zeroed -> background

            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = _box_to_pixels(obj_box, H, W)
                inside = att_map_obj[y_min:y_max, x_min:x_max]
                max_inside = inside.max() if inside.numel() > 0 else 1.
                total_loss += 1. - max_inside
                att_copy[y_min:y_max, x_min:x_max] = 0.
                other_att_map_obj[y_min:y_max, x_min:x_max] = 0.

            for obj_outside in range(obj_number):
                if obj_outside == obj_idx:
                    continue
                for obj_out_box in bboxes[obj_outside]:
                    x_min_out, y_min_out, x_max_out, y_max_out = _box_to_pixels(obj_out_box, H, W)
                    region = other_att_map_obj[y_min_out:y_max_out, x_min_out:x_max_out]
                    max_outside_one = region.max() if region.numel() > 0 else 0
                    att_copy[y_min_out:y_max_out, x_min_out:x_max_out] = 0.
                    total_loss += max_outside_one

            max_background = att_copy.max()
            total_loss += len(bboxes[obj_idx]) * max_background / 2.
    return total_loss / obj_number


def _loco_loss(attn_map, bboxes, object_positions, smooth_att, sigma, kernel_size,
               average_tokens, bce_with_logits):
    """Shared body of ``caculate_loss_LoCo`` and ``caculate_loss_LoCo_64``.

    Args:
        attn_map: averaged cross-attention indexed as ``(H, H, num_tokens)``.
        average_tokens: combine an object's token maps by mean (True) or sum (False).
        bce_with_logits: score the padding map with
            ``F.binary_cross_entropy_with_logits`` (True) or with
            ``F.binary_cross_entropy`` on its sigmoid (False).

    Returns:
        ``(1 - in / (in + out)) ** 2 + alpha * pad_bce``. Here ``in`` and ``out``
        are the max-normalised object-token attention summed inside and outside
        each object's boxes, accumulated over all objects.
    """
    alpha = 0.2  # weight of the padding-token (SOT/EOT) term
    beta = 0.8   # SOT share inside the padding map; the EOT map gets 1 - beta
    H = W = attn_map.shape[0]
    device = attn_map.device

    sum_in = 0.
    sum_out = 0.
    for obj_idx in range(len(bboxes)):
        mask = _boxes_mask(bboxes[obj_idx], H, W, device)
        # object_positions[obj_idx] is a list of token indices, e.g. [6].
        tokens = attn_map[:, :, object_positions[obj_idx]]
        ca_map_obj = tokens.mean(-1) if average_tokens else tokens.sum(-1)
        if smooth_att:
            ca_map_obj = _smooth_map(ca_map_obj, kernel_size, sigma)
        norm_ca_map_obj = ca_map_obj.reshape(H, W) / ca_map_obj.max()
        sum_in += (norm_ca_map_obj * mask).sum()
        sum_out += (norm_ca_map_obj * (1 - mask)).sum()
    loss = (1 - sum_in / (sum_in + sum_out)) ** 2

    pad_map = _pad_map(attn_map, beta)
    all_boxes = [box for boxes in bboxes for box in boxes]
    # The BCE target is the padding map itself restricted to the union of all
    # boxes; as in the original it is not detached.
    fg_map = pad_map * _boxes_mask(all_boxes, H, W, device)
    if bce_with_logits:
        pad_loss = F.binary_cross_entropy_with_logits(pad_map.reshape(-1), fg_map.reshape(-1))
    else:
        pad_loss = F.binary_cross_entropy(torch.sigmoid(pad_map.reshape(-1)), fg_map.reshape(-1))
    return loss + alpha * pad_loss


def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t,
                       res=16, smooth_att=False, sigma=0.5, kernel_size=3):
    """LoCo loss on the averaged ``res x res`` cross-attention map.

    Object token maps are combined by mean; the padding term uses
    ``binary_cross_entropy_with_logits``. Masks now follow the attention
    map's size and device (previously a hard-coded 16x16 CUDA tensor, which
    broke for ``res != 16`` and on CPU). Returns a zero scalar with no objects.
    ``t`` is accepted but unused.
    """
    if len(bboxes) == 0:
        return _zero_loss()
    attn_map = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
    return _loco_loss(attn_map, bboxes, object_positions, smooth_att, sigma, kernel_size,
                      average_tokens=True, bce_with_logits=True)


def caculate_loss_LoCo_64(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t,
                          res=16, smooth_att=True, sigma=0.5, kernel_size=3):
    """LoCo loss on cross-attention upsampled to 64x64 (see ``get_all_attention_64``).

    Object token maps are combined by sum; the padding term uses
    ``binary_cross_entropy`` on the sigmoid of the padding map. Returns a zero
    scalar with no objects. ``t`` is accepted but unused.
    """
    if len(bboxes) == 0:
        return _zero_loss()
    attn_map = get_all_attention_64(attn_maps_mid, attn_maps_up, attn_maps_down, res)
    return _loco_loss(attn_map, bboxes, object_positions, smooth_att, sigma, kernel_size,
                      average_tokens=False, bce_with_logits=False)


def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t,
                          res=16, smooth_att=True, sigma=0.5, kernel_size=3):
    """Per-object LoCo variant on cross-attention upsampled to 64x64.

    Unlike ``caculate_loss_LoCo``, ``(1 - in / (in + out)) ** 2`` is computed per
    object and summed. The padding-token BCE (with logits) is evaluated in
    float16. ``smooth_att``, ``sigma``, ``kernel_size`` and ``t`` are accepted
    but unused.

    Returns a zero scalar when there are no objects, or when the last object's
    normalised attention sums to zero.
    """
    object_number = len(bboxes)
    if object_number == 0:
        return _zero_loss()

    attn_map = get_all_attention_64(attn_maps_mid, attn_maps_up, attn_maps_down, res)
    alpha = 0.2  # weight of the padding-token term
    beta = 0.8   # SOT share inside the padding map
    H = W = attn_map.shape[0]
    device = attn_map.device

    loss = 0.
    sum_in = sum_out = 0.
    for obj_idx in range(object_number):
        mask = _boxes_mask(bboxes[obj_idx], H, W, device)
        positions = object_positions[obj_idx]
        tokens = attn_map[:, :, positions]
        ca_map_obj = tokens.mean(-1) if len(positions) > 1 else tokens
        ca_map_obj = ca_map_obj.reshape(H, W)
        norm_attn = ca_map_obj / ca_map_obj.max()
        # An IoU-based weighting was prototyped here but disabled (iou fixed at
        # 0), so every object uses weight 1.
        sum_in = (norm_attn * mask).sum()
        sum_out = (norm_attn * (1 - mask)).sum()
        loss += (1 - sum_in / (sum_in + sum_out)) ** 2

    pad_map = _pad_map(attn_map, beta)
    all_boxes = [box for boxes in bboxes for box in boxes]
    fg_map = pad_map * _boxes_mask(all_boxes, H, W, device)
    pad_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1),
                                                  fg_map.to(torch.float16).reshape(-1))

    if sum_in + sum_out == 0:
        return _zero_loss(device)
    return loss + alpha * pad_loss