Update gligen/ldm/models/diffusion/loss.py
gligen/ldm/models/diffusion/loss.py
CHANGED
@@ -566,120 +566,16 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
         total_fg_mask = total_fg_map
         fg_map = pad_map * total_fg_mask
 
-        bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
+        # bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
 
-        pad_loss += bce_loss
+        # pad_loss += bce_loss
     if sum_in + sum_out == 0:
         return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
     # loss += (1 - sum_in / (sum_in + sum_out)) ** 2
     # print('loss', loss)
     # return loss
-    return loss
-
-
-
-def caculate_loss_LAC(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att=True, sigma=0.5, kernel_size=3):
-    attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
-    all_attn = [attn16]
-
-
-    loss = 0.
-    pad_loss = 0.
-    total_fg_map = torch.zeros(size=(16, 16)).cuda()
-
-    # alpha is the weight of the pad loss
-    # beta is the weight inside the pad loss, e.g. beta for the SOT map and 1 - beta for the EOT map
-    alpha = 0.2
-    beta = 0.8
-
-    object_number = len(bboxes)
-    if object_number == 0:
-        return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
-    # attn16 = get_all_attention(attn_maps_down[-1], attn_maps_mid, attn_maps_up[0], 16)
-    # all_attn = [attn16]
-    max_loss = 0
-
-
-    for attn_map in all_attn:
-        # print(attn_map.shape)
-        # originally [8, 64, 77]; only the second half is taken, so attn_map is [4, 64, 77]
-        sum_in = 0.
-        sum_out = 0.
-
-        i, j, k = attn_map.shape
-        H = W = i  # 8 here
-        for obj_idx in range(object_number):  # for each box
-            obj_loss = 0
-            mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
-            for obj_box in bboxes[obj_idx]:
-
-                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
-                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
-                mask[y_min: y_max, x_min: x_max] = 1  # mask starts as all zeros; the current object's box region is set to 1
-                total_fg_map[y_min: y_max, x_min: x_max] = 1
-
-            # pick out the object's token positions (i.e. the maps of those tokens) and reshape to [4, 16, 16]
-            for obj_position in [object_positions[obj_idx]]:  # note: object_positions is a list like [[6], [10]], i.e. the first object sits at token 6 and the second at token 10
-                # take the map at the object's position (e.g. [6]), then reshape to [4, 16, 16]
-
-                # print(attn_map[:, :, obj_position].shape)
-                ca_map_obj = attn_map[:, :, obj_position].sum(-1)
-
-                print(ca_map_obj.shape)
-                if smooth_att:
-                    smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
-                    input = F.pad(ca_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
-                    ca_map_obj = smoothing(input).squeeze(0).squeeze(0)
-
-                ca_map_obj = ca_map_obj.reshape(H, W)
-                norm_ca_map_obj = ca_map_obj / ca_map_obj.max()
-
-                norm_ca_map_obj = norm_ca_map_obj.reshape(H, W)
-
-                # avg_fg_value = torch.mean(ca_map_obj * mask)
-                # print('avg_fg_value', avg_fg_value)
-
-                sum_in += (norm_ca_map_obj * mask).sum()
-                sum_out += (norm_ca_map_obj * (1 - mask)).sum()
-
-            # each object here corresponds to a single box, so the length is 1
-            loss += (obj_loss / len(object_positions[obj_idx]))
-
-        # get pad_loss
-        # sot_map = attn_map[:, :, 0].reshape(H, W)
-        # eot_map = attn_map[:, :, -1].reshape(H, W)
-
-        # norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
-        # norm_eot_map = eot_map / eot_map.max()
-
-
-        # pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
-
-        # pad_map = pad_map.to(torch.float64)
-
-        # total_fg_mask = total_fg_map  # .to(torch.float64)
-        # fg_map = pad_map * total_fg_mask
-
-        # print(fg_map.shape)
-        # print(pad_map.shape)
-        # fg_map = torch.sigmoid(fg_map)
-
-        # mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
-        # bce_loss = F.binary_cross_entropy(torch.sigmoid(pad_map.reshape(-1)), fg_map.reshape(-1))
-        # print('mse_loss', mse_loss)
-        # print('bce_loss', bce_loss)
-        # bce_loss = torch.clamp(bce_loss, max=0.99)
-        # pad_loss += mse_loss
-        # pad_loss += bce_loss
-        # pad_loss += (1 - torch.mean((pad_map * total_fg_map).reshape(-1).sum(dim=-1) / pad_map.reshape(-1).sum(dim=-1))) ** 2
-
-        # print('optimization for this step finished')
-
-
-        loss += (1 - sum_in / (sum_in + sum_out)) ** 2
-        # loss += max_loss
-        # print('loss', loss)
-        # print('pad_loss', alpha * pad_loss)
-
-
-    return loss + alpha * pad_loss
+    return loss #+ alpha * pad_loss
+
+
+
+
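After this change, caculate_loss_LoCo_V2 returns only the inside/outside attention-ratio term; the BCE pad loss is commented out. Below is a minimal, self-contained sketch of that ratio objective. The function name box_ratio_loss and the tensor shapes are illustrative assumptions for this note, not the actual API of loss.py.

    import torch

    def box_ratio_loss(attn_map, masks):
        # attn_map: [HW, K] cross-attention mass for each of K object tokens (assumed shape)
        # masks:    [K, H, W] binary box masks (assumed shape)
        K, H, W = masks.shape
        sum_in = attn_map.new_zeros(())
        sum_out = attn_map.new_zeros(())
        for k in range(K):
            ca = attn_map[:, k].reshape(H, W)
            ca = ca / ca.max()                               # normalize each object's map
            sum_in = sum_in + (ca * masks[k]).sum()          # attention inside the box
            sum_out = sum_out + (ca * (1 - masks[k])).sum()  # attention leaking outside
        # drive the in-box share of attention toward 1
        return (1 - sum_in / (sum_in + sum_out)) ** 2

    # toy usage: an 8x8 map with two boxed objects
    attn = torch.rand(64, 2)
    masks = torch.zeros(2, 8, 8)
    masks[0, 1:4, 1:4] = 1
    masks[1, 4:7, 4:7] = 1
    print(box_ratio_loss(attn, masks))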
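For reference, the disabled pad-loss term paired the padding-token attention map against its box-masked copy with a logits BCE. A rough illustration of that pairing, with random tensors standing in for the real SOT/EOT attention maps (all names and values here are hypothetical stand-ins):

    import torch
    import torch.nn.functional as F

    pad_map = torch.rand(16, 16)     # stand-in for the padding-token attention map
    total_fg_map = torch.zeros(16, 16)
    total_fg_map[4:12, 4:12] = 1     # union of all object boxes (hypothetical)
    fg_map = pad_map * total_fg_map  # pad attention restricted to the boxes
    bce_loss = F.binary_cross_entropy_with_logits(pad_map.reshape(-1), fg_map.reshape(-1))
    print(bce_loss)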