Pusheen committed on
Commit
e5dd0f3
1 Parent(s): 661ec7d

Update gligen/ldm/models/diffusion/loss.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/loss.py +10 -10
gligen/ldm/models/diffusion/loss.py CHANGED
@@ -554,27 +554,27 @@ def caculate_loss_LoCo_V2(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, o
554
 
555
  loss += (obj_loss) # /len(object_positions[obj_idx])
556
 
557
- # sot_map = attn_map[:, :, 0].reshape(H, W)
558
- # eot_map = attn_map[:, :, -1].reshape(H, W)
559
 
560
- # norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
561
- # norm_eot_map = eot_map / eot_map.max()
562
 
563
 
564
- # pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
565
 
566
- # total_fg_mask = total_fg_map
567
- # fg_map = pad_map * total_fg_mask
568
 
569
- # bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
570
 
571
- # pad_loss += bce_loss
572
  if sum_in + sum_out == 0:
573
  return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
574
  # loss += (1 - sum_in / (sum_in + sum_out)) ** 2
575
  # print('loss', loss)
576
  # return loss
577
- return loss #+ alpha * pad_loss
578
 
579
 
580
 
 
554
 
555
  loss += (obj_loss) # /len(object_positions[obj_idx])
556
 
557
+ sot_map = attn_map[:, :, 0].reshape(H, W)
558
+ eot_map = attn_map[:, :, -1].reshape(H, W)
559
 
560
+ norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
561
+ norm_eot_map = eot_map / eot_map.max()
562
 
563
 
564
+ pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
565
 
566
+ total_fg_mask = total_fg_map
567
+ fg_map = pad_map * total_fg_mask
568
 
569
+ bce_loss = F.binary_cross_entropy_with_logits(pad_map.to(torch.float16).reshape(-1), fg_map.to(torch.float16).reshape(-1))
570
 
571
+ pad_loss += bce_loss
572
  if sum_in + sum_out == 0:
573
  return torch.tensor(0).float().cuda() if torch.cuda.is_available() else torch.tensor(0).float()
574
  # loss += (1 - sum_in / (sum_in + sum_out)) ** 2
575
  # print('loss', loss)
576
  # return loss
577
+ return loss + alpha * pad_loss
578
 
579
 
580