Pusheen committed
Commit f78dc6a
1 Parent(s): 4c96567

Update gligen/ldm/models/diffusion/loss.py

Files changed (1):
  1. gligen/ldm/models/diffusion/loss.py +3 -107
gligen/ldm/models/diffusion/loss.py CHANGED
@@ -566,120 +566,16 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
     total_fg_mask = total_fg_map
     fg_map = pad_map * total_fg_mask
 
-    bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
+    # bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
 
-    pad_loss += bce_loss
+    # pad_loss += bce_loss
     if sum_in + sum_out == 0:
         return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
     # loss += (1 - sum_in / (sum_in + sum_out)) ** 2
     # print('loss', loss)
     # return loss
-    return loss + alpha * pad_loss
-
-
-
-def caculate_loss_LAC(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ):
-    attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
-    all_attn = [attn16]
-
-
-    loss = 0.
-    pad_loss = 0.
-    total_fg_map = torch.zeros(size=(16, 16)).cuda()
-
-    # alpha is the weight of the pad loss
-    # beta is the weighting inside the pad loss, e.g. beta for SOT and 1 - beta for EOT
-    alpha = 0.2
-    beta = 0.8
-
-    object_number = len(bboxes)
-    if object_number == 0:
-        return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
-    # attn16 = get_all_attention(attn_maps_down[-1], attn_maps_mid, attn_maps_up[0], 16)
-    # all_attn = [attn16]
-    max_loss = 0
+    return loss #+ alpha * pad_loss
 
 
-    for attn_map in all_attn:
-        # print(attn_map.shape)
-        # originally [8, 64, 77]; only the second half is kept, attn_map [4, 64, 77]
-        sum_in = 0.
-        sum_out = 0.
-
-        i, j, k = attn_map.shape
-        H = W = i  # here this is 8
-        for obj_idx in range(object_number):  # for each box
-            obj_loss = 0
-            mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
-            for obj_box in bboxes[obj_idx]:
 
-                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
-                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
-                mask[y_min: y_max, x_min: x_max] = 1  # mask is an all-zero matrix; the current object's box region is set to 1
-                total_fg_map[y_min: y_max, x_min: x_max] = 1
 
-            # select the object's position among the tokens (i.e. that token's map), reshape to [4, 16, 16]
-            for obj_position in [object_positions[obj_idx]]:  # note: object_positions is a list like [[6], [10]], meaning the first object is at token 6 and the second at token 10
-                # select the map at the object's position (e.g. [6]), then reshape to [4, 16, 16]
-
-                # print(attn_map[:, :, obj_position].shape)
-                ca_map_obj = attn_map[:, :, obj_position].sum(-1)
-
-                print(ca_map_obj.shape)
-                if smooth_att:
-                    smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
-                    input = F.pad(ca_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
-                    ca_map_obj = smoothing(input).squeeze(0).squeeze(0)
-
-                ca_map_obj = ca_map_obj.reshape(H, W)
-                norm_ca_map_obj = ca_map_obj / ca_map_obj.max()
-
-                norm_ca_map_obj = norm_ca_map_obj.reshape(H, W)
-
-                # avg_fg_value = torch.mean(ca_map_obj * mask)
-                # print('avg_fg_value', avg_fg_value)
-
-                sum_in += (norm_ca_map_obj * mask).sum()
-                sum_out += (norm_ca_map_obj * (1 - mask)).sum()
-
-            # here each object corresponds to one box, so the len is 1
-            loss += (obj_loss/len(object_positions[obj_idx]))
-
-        # get pad_loss
-        #sot_map = attn_map[:, :, 0].reshape(H, W)
-        #eot_map = attn_map[:, :, -1].reshape(H, W)
-
-        #norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
-        #norm_eot_map = eot_map / eot_map.max()
-
-
-        #pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
-
-        # pad_map = pad_map.to(torch.float64)
-
-        #total_fg_mask = total_fg_map#.to(torch.float64)
-        #fg_map = pad_map * total_fg_mask
-
-        # print(fg_map.shape)
-        # print(pad_map.shape)
-        # fg_map = torch.sigmoid(fg_map)
-
-        # mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
-        #bce_loss = F.binary_cross_entropy(torch.sigmoid(pad_map.reshape(-1)), fg_map.reshape(-1))
-        # print('mse_loss', mse_loss)
-        # print('bce_loss', bce_loss)
-        #bce_loss = torch.clamp(bce_loss, max=0.99)
-        # pad_loss += mse_loss
-        #pad_loss += bce_loss
-        #pad_loss += (1 - torch.mean((pad_map * total_fg_map).reshape(-1).sum(dim=-1) / pad_map.reshape(-1).sum(dim=-1)) ) **2
-
-        # print('optimization for this step finished')
-
-
-        loss += (1 - sum_in / (sum_in + sum_out)) ** 2
-        # loss += max_loss
-        # print('loss', loss)
-        # print('pad_loss', alpha * pad_loss)
-
-
-    return loss + alpha * pad_loss
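In effect, this commit disables the alpha * pad_loss BCE term in caculate_loss_LoCo_V2 (the function now returns only the box-attention loss) and deletes the unused caculate_loss_LAC variant. A minimal sketch of the disabled term, using hypothetical stand-in tensors for the pad_map and total_fg_map built earlier in the function:

    import torch
    import torch.nn.functional as F

    # Hypothetical stand-ins: pad_map is a normalized attention map in [0, 1];
    # total_fg_map is the 0/1 union mask of all object boxes.
    pad_map = torch.rand(16, 16)
    total_fg_map = torch.zeros(16, 16)
    total_fg_map[4:12, 4:12] = 1

    fg_map = pad_map * total_fg_map
    # The disabled term: pad_map serves as the logits, its masked copy as the target.
    bce_loss = F.binary_cross_entropy_with_logits(pad_map.reshape(-1), fg_map.reshape(-1))

Note that binary_cross_entropy_with_logits expects raw logits as its first argument, while pad_map here is already probability-like and also appears (masked) as the target; disabling the term sidesteps that mismatch. The .to(torch.float16) casts in the original call are also not needed for stability, since the with-logits form, unlike F.binary_cross_entropy, is safe in half precision.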
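For readers tracing the deleted caculate_loss_LAC: its first step rasterizes each object's normalized box onto the attention grid. A self-contained sketch of that step (the helper name and defaults are illustrative, not from the repo):

    import torch

    def boxes_to_mask(boxes, H=16, W=16):
        # Rasterize normalized [x_min, y_min, x_max, y_max] boxes onto an H x W grid.
        mask = torch.zeros(H, W)
        for box in boxes:
            x0, y0 = int(box[0] * W), int(box[1] * H)
            x1, y1 = int(box[2] * W), int(box[3] * H)
            mask[y0:y1, x0:x1] = 1  # 1 inside the box, 0 elsewhere
        return mask

    # A centered box covering the middle half of a 16 x 16 map:
    m = boxes_to_mask([[0.25, 0.25, 0.75, 0.75]])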
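The GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2) module used in the deleted code is defined elsewhere in the repo; assuming it is a plain 2D Gaussian blur, a rough standalone equivalent of the smoothing step is:

    import torch
    import torch.nn.functional as F

    def gaussian_smooth(ca_map, kernel_size=3, sigma=0.5):
        # Build a normalized 2D Gaussian kernel.
        ax = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
        g = torch.exp(-(ax ** 2) / (2 * sigma ** 2))
        kernel = torch.outer(g, g)
        kernel = (kernel / kernel.sum()).view(1, 1, kernel_size, kernel_size)
        # Reflect-pad by 1 (matching kernel_size=3) so the H x W shape is preserved.
        x = F.pad(ca_map.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
        return F.conv2d(x, kernel).squeeze(0).squeeze(0)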
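Both the surviving caculate_loss_LoCo_V2 path and the deleted variant share the same core objective, (1 - sum_in / (sum_in + sum_out)) ** 2: it is zero when all max-normalized attention mass for an object's token lies inside its box, and grows toward one as the mass leaks out. A minimal sketch (names are illustrative; ca_map is assumed non-zero):

    import torch

    def inside_ratio_loss(ca_map, mask):
        # ca_map: [H, W] cross-attention map for one object's token; mask: 0/1 box mask.
        norm = ca_map / ca_map.max()                   # max-normalize, as in the diff
        sum_in = (norm * mask).sum()                   # attention mass inside the box
        sum_out = (norm * (1 - mask)).sum()            # attention mass outside the box
        return (1 - sum_in / (sum_in + sum_out)) ** 2  # 0 when all mass is inside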