# darshanjani's picture
# utils function for inference
# 3a0062c
# raw
# history blame contribute delete
# No virus
# 2.84 kB
"""
Implementation of a YOLO loss function similar to the one in the YOLOv3 paper.
As far as I can tell, the main difference is that class predictions use
CrossEntropy (softmax over classes) instead of BinaryCrossEntropy.
"""
import random
import pytorch_lightning as pl
import torch
import torch.nn as nn
from .utils import intersection_over_union
class YoloLoss(pl.LightningModule):
    """YOLOv3-style loss for a single prediction scale.

    The returned loss is a weighted sum of four terms:

    * no-object loss: BCE-with-logits on the objectness logit of cells whose
      target objectness is 0;
    * object loss: MSE between the sigmoid objectness and (IoU of the decoded
      predicted box with the target box) * target objectness, on cells whose
      target objectness is 1;
    * box loss: MSE between (sigmoid(x), sigmoid(y), raw w, raw h) and
      (x, y, log(w / anchor_w), log(h / anchor_h)) on object cells;
    * class loss: cross-entropy on the class logits of object cells.

    Cells whose target objectness is -1 are ignored by every term.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Weights of the four loss terms.
        self.lambda_class = 1
        self.lambda_noobj = 5
        self.lambda_obj = 1
        self.lambda_box = 1

    @staticmethod
    def _graph_zero(predictions):
        """Return a scalar 0 that is still attached to ``predictions``' graph."""
        # Summing an empty slice yields exactly 0 (never NaN, even if the
        # predictions contain inf/NaN) while keeping requires_grad, so
        # backward() still works when every term is skipped.
        return predictions[..., :0].sum()

    def _object_terms(self, predictions, target, anchors, obj):
        """Return (object_loss, box_loss, class_loss) over the cells in ``obj``.

        Must only be called when ``obj`` selects at least one cell: the mean
        reductions of MSELoss / CrossEntropyLoss are NaN on empty input.
        Neither ``predictions`` nor ``target`` is modified.
        """
        # One (w, h) pair per anchor, broadcast over batch and grid axes.
        anchors = anchors.reshape(1, -1, 1, 1, 2)

        pred_xy = self.sigmoid(predictions[..., 1:3])  # x, y coordinates
        pred_twh = predictions[..., 3:5]  # raw width/height outputs

        # Object loss: the confidence target is the IoU between the decoded
        # predicted box and the ground-truth box; detached so the IoU acts as
        # a fixed regression target rather than a path for gradients.
        box_preds = torch.cat([pred_xy, torch.exp(pred_twh) * anchors], dim=-1)
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.mse(
            self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]
        )

        # Box loss, compared in the network's own parameterization. Built
        # out-of-place so the caller's tensors are left untouched.
        target_twh = torch.log(1e-16 + target[..., 3:5] / anchors)
        box_pred = torch.cat([pred_xy, pred_twh], dim=-1)
        box_target = torch.cat([target[..., 1:3], target_twh], dim=-1)
        box_loss = self.mse(box_pred[obj], box_target[obj])

        # Class loss: target[..., 5] holds the class index.
        class_loss = self.entropy(
            predictions[..., 5:][obj],
            target[..., 5][obj].long(),
        )
        return object_loss, box_loss, class_loss

    def forward(self, predictions, target, anchors):
        """Compute the weighted YOLO loss for one scale.

        Args:
            predictions: raw network output. The anchor reshape implies a
                rank-5 tensor with the anchor axis at dim 1 and a last axis laid
                out as [objectness, x, y, w, h, class logits...]. Presumably
                this is (batch, num_anchors, S, S, 5 + num_classes); confirm
                with the caller.
            target: tensor indexed like ``predictions`` whose last axis is
                [objectness (1 = object, 0 = no object, -1 = ignore), x, y, w,
                h, class index]. x/y are compared against sigmoid outputs, so
                they are presumably cell-relative offsets in [0, 1].
            anchors: (w, h) anchor pairs for this scale, in the same units as
                the target w/h.

        Returns:
            Scalar tensor ``lambda_box * box + lambda_obj * object +
            lambda_noobj * no_object + lambda_class * class``. A term whose cell
            mask is empty contributes 0 instead of NaN.

        ``predictions`` and ``target`` are not modified.
        """
        # Which cells contain an object / are background (-1 cells are in neither).
        obj = target[..., 0] == 1
        noobj = target[..., 0] == 0
        zero = self._graph_zero(predictions)

        # NOTE: ``mask.any()`` in a Python ``if`` forces a device sync; that is
        # the price of avoiding NaN losses on batches with empty masks.
        if noobj.any():
            no_object_loss = self.bce(
                predictions[..., 0:1][noobj],
                target[..., 0:1][noobj],
            )
        else:
            no_object_loss = zero

        if obj.any():
            object_loss, box_loss, class_loss = self._object_terms(
                predictions, target, anchors, obj
            )
        else:
            object_loss = box_loss = class_loss = zero

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )