Spaces:
Sleeping
Sleeping
Update gligen/ldm/models/diffusion/loss.py
Browse files
gligen/ldm/models/diffusion/loss.py
CHANGED
@@ -554,27 +554,27 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
|
|
554 |
|
555 |
loss += (obj_loss) # /len(object_positions[obj_idx])
|
556 |
|
557 |
-
|
558 |
-
|
559 |
|
560 |
-
|
561 |
-
|
562 |
|
563 |
|
564 |
-
|
565 |
|
566 |
-
|
567 |
-
|
568 |
|
569 |
-
|
570 |
|
571 |
-
|
572 |
if sum_in + sum_out == 0:
|
573 |
return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
|
574 |
# loss += (1 - sum_in / (sum_in + sum_out)) ** 2
|
575 |
# print('loss', loss)
|
576 |
# return loss
|
577 |
-
return loss
|
578 |
|
579 |
|
580 |
|
|
|
554 |
|
555 |
loss += (obj_loss) # /len(object_positions[obj_idx])
|
556 |
|
557 |
+
sot_map = attn_map[:, :, 0].reshape(H, W)
|
558 |
+
eot_map = attn_map[:, :, -1].reshape(H, W)
|
559 |
|
560 |
+
norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
|
561 |
+
norm_eot_map = eot_map / eot_map.max()
|
562 |
|
563 |
|
564 |
+
pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
|
565 |
|
566 |
+
total_fg_mask = total_fg_map
|
567 |
+
fg_map = pad_map * total_fg_mask
|
568 |
|
569 |
+
bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
|
570 |
|
571 |
+
pad_loss += bce_loss
|
572 |
if sum_in + sum_out == 0:
|
573 |
return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
|
574 |
# loss += (1 - sum_in / (sum_in + sum_out)) ** 2
|
575 |
# print('loss', loss)
|
576 |
# return loss
|
577 |
+
return loss + alpha * pad_loss
|
578 |
|
579 |
|
580 |
|