|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dis import dis |
|
import torch |
|
from torch import nn |
|
import torch.distributed as dist |
|
from torch.functional import Tensor |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
def compound_loss(coe, output_feature, image: Tensor, output_label, targets,
                  criterion_bce, criterion_ce, epoch, target_range=(0.01, 0.99)):
    """Combine per-stage reconstruction and classification losses into one scalar.

    For every index ``i`` over ``output_feature`` the stage term is::

        criterion_bce(output_feature[i], clamped_image) * (1 - i / len(output_feature)) * f_coe
        + criterion_ce(output_label[i], targets) * ((i + 1) / len(output_label)) * c_coe

    so the BCE weight decreases and the CE weight increases with ``i``.
    An extra, unweighted ``criterion_ce(output_label[-1], targets)`` term is
    appended last, and all terms are summed.

    Args:
        coe: ``(f_coe, c_coe)`` multipliers for the BCE and CE terms.
        output_feature: sequence of per-stage outputs compared against the
            clamped image with ``criterion_bce`` (presumably probabilities in
            [0, 1] if ``criterion_bce`` is ``nn.BCELoss`` -- confirm).
        image: BCE target. It is clamped to ``target_range`` OUT OF PLACE;
            the caller's tensor is not modified.
        output_label: sequence of per-stage predictions for ``criterion_ce``.
            It must be non-empty (``[-1]`` is read) and have at least
            ``len(output_feature)`` entries (``[i]`` is read per stage).
        targets: targets passed to ``criterion_ce``.
        criterion_bce: loss callable ``(prediction, target) -> Tensor``.
        criterion_ce: loss callable ``(prediction, targets) -> Tensor``.
        epoch: unused; kept for call-site compatibility.
        target_range: ``(low, high)`` bounds used to clamp ``image``.

    Returns:
        ``(loss, multi_loss)``: the summed loss and the list of individual
        terms (``len(output_feature) + 1`` entries, the last being the
        unweighted final CE term).
    """
    f_coe, c_coe = coe
    low, high = target_range

    # Clamp out of place. An in-place ``image.clamp_`` mutates the caller's
    # batch and, when ``image`` was also the model input saved by autograd,
    # bumps its version counter so ``loss.backward()`` fails.
    target = image.clamp(low, high)

    n_feat = len(output_feature)
    n_label = len(output_label)

    multi_loss = []
    for i, feature in enumerate(output_feature):
        ratio_f = 1 - i / n_feat
        ratio_c = (i + 1) / n_label

        ihx = criterion_bce(feature, target) * ratio_f * f_coe
        ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
        multi_loss.append(ihx + ihy)

    # Final prediction's CE loss, with weight 1 and independent of ``coe``.
    multi_loss.append(criterion_ce(output_label[-1], targets))

    loss = torch.sum(torch.stack(multi_loss), dim=0)
    return loss, multi_loss
|
|