Pusheen commited on
Commit
4c96567
1 Parent(s): 2fe6c8a

Update gligen/ldm/models/diffusion/loss.py

Browse files
gligen/ldm/models/diffusion/loss.py CHANGED
@@ -354,7 +354,7 @@ def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, obje
354
  # fg_map = torch.sigmoid(fg_map)
355
 
356
  # mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
357
- bce_loss = F.binary_cross_entropy(torch.sigmoid(pad_map.reshape(-1)), fg_map.reshape(-1))
358
  # print('mse_loss', mse_loss)
359
  # print('bce_loss', bce_loss)
360
  #bce_loss = torch.clamp(bce_loss, max=0.99)
@@ -566,7 +566,7 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
566
  total_fg_mask = total_fg_map
567
  fg_map = pad_map * total_fg_mask
568
 
569
- bce_loss = F.binary_cross_entropy(torch.sigmoid(pad_map.to(torch.float16).reshape(-1)), fg_map.to(torch.float16).reshape(-1))
570
 
571
  pad_loss += bce_loss
572
  if sum_in + sum_out == 0:
 
354
  # fg_map = torch.sigmoid(fg_map)
355
 
356
  # mse_loss = F.mse_loss(pad_map.reshape(-1), fg_map.reshape(-1))
357
+ bce_loss = F.binary_cross_entropy_with_logits(pad_map.reshape(-1), fg_map.reshape(-1))
358
  # print('mse_loss', mse_loss)
359
  # print('bce_loss', bce_loss)
360
  #bce_loss = torch.clamp(bce_loss, max=0.99)
 
566
  total_fg_mask = total_fg_map
567
  fg_map = pad_map * total_fg_mask
568
 
569
+ bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
570
 
571
  pad_loss += bce_loss
572
  if sum_in + sum_out == 0: