Spaces:
Sleeping
Sleeping
Update gligen/ldm/models/diffusion/loss.py
Browse files
gligen/ldm/models/diffusion/loss.py
CHANGED
@@ -354,7 +354,7 @@ def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, obje
|
|
354 |
# fg_map = torch.sigmoid(fg_map)
|
355 |
|
356 |
# mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
|
357 |
-
bce_loss = F.
|
358 |
# print('mse_loss', mse_loss)
|
359 |
# print('bce_loss', bce_loss)
|
360 |
#bce_loss = torch.clamp(bce_loss, max=0.99)
|
@@ -566,7 +566,7 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
|
|
566 |
total_fg_mask = total_fg_map
|
567 |
fg_map = pad_map * total_fg_mask
|
568 |
|
569 |
-
bce_loss = F.
|
570 |
|
571 |
pad_loss += bce_loss
|
572 |
if sum_in + sum_out == 0:
|
|
|
354 |
# fg_map = torch.sigmoid(fg_map)
|
355 |
|
356 |
# mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
|
357 |
+
bce_loss = F.binary_cross_entropy_with_logits(pad_map.reshape(-1), fg_map.reshape(-1))
|
358 |
# print('mse_loss', mse_loss)
|
359 |
# print('bce_loss', bce_loss)
|
360 |
#bce_loss = torch.clamp(bce_loss, max=0.99)
|
|
|
566 |
total_fg_mask = total_fg_map
|
567 |
fg_map = pad_map * total_fg_mask
|
568 |
|
569 |
+
bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
|
570 |
|
571 |
pad_loss += bce_loss
|
572 |
if sum_in + sum_out == 0:
|