darshanjani's picture
utils function for inference
3a0062c
raw
history blame
No virus
7.81 kB
"""
Implementation of YOLOv3 architecture
"""
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from . import config
from .loss import YoloLoss
# Layer specification consumed by YOLOv3._create_conv_layers. Entry kinds:
#   (out_channels, kernel_size, stride) -> CNNBlock (padding=1 for 3x3, else 0)
#   ["B", num_repeats]                  -> ResidualBlock with num_repeats
#                                          1x1/3x3 CNNBlock pairs
#   "S"                                 -> prediction branch: ResidualBlock
#                                          without skip, 1x1 CNNBlock halving
#                                          channels, then a ScalePrediction head
#   "U"                                 -> 2x nn.Upsample; forward() then
#                                          concatenates the most recently saved
#                                          8-repeat ResidualBlock output
# The order of entries fixes the module order (and therefore state_dict keys).
model_config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],  # output saved as a route connection in forward()
    (512, 3, 2),
    ["B", 8],  # output saved as a route connection in forward()
    (1024, 3, 2),
    ["B", 4], # darknet 53 ends here
    (512, 1, 1),
    (1024, 3, 1),
    "S",  # first (coarsest) prediction head
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",  # second prediction head
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S"  # third (finest) prediction head
]
class CNNBlock(pl.LightningModule):
    """A 2-D convolution optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by the convolution.
        bn_act: when True, the conv is built without bias and its output goes
            through batch norm + leaky ReLU; when False, the raw conv output
            (with bias) is returned.
        **kwargs: passed straight to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # A conv bias is redundant when BatchNorm immediately re-centres the output.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # bn/leaky are constructed unconditionally, so bn.* parameters exist in
        # the state_dict even when bn_act is False.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        normalized = self.bn(features)
        return self.leaky(normalized)
class ResidualBlock(pl.LightningModule):
    """A stack of ``num_repeats`` (1x1 squeeze -> 3x3 expand) CNNBlock pairs.

    Each pair halves and then restores the channel count with stride-1 convs,
    so its output has the same shape as its input. With ``use_residual`` each
    pair's output is added to its input (skip connection); otherwise the pairs
    are simply chained.

    Args:
        channels: input and output channel count.
        use_residual: whether to add the skip connection around each pair.
        num_repeats: number of pairs in the stack.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        half = channels // 2
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, half, kernel_size=1),
                CNNBlock(half, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        self.use_residual = use_residual
        # Kept as an attribute: YOLOv3.forward inspects it to pick route connections.
        self.num_repeats = num_repeats

    def forward(self, x):
        for pair in self.layers:
            transformed = pair(x)
            x = x + transformed if self.use_residual else transformed
        return x
class ScalePrediction(pl.LightningModule):
    """Prediction head for one output scale.

    A 3x3 CNNBlock doubles the channels, then a plain 1x1 conv (no BN/activation)
    emits ``3 * (num_classes + 5)`` channels: 3 anchors, each with
    ``num_classes + 5`` values. The exact meaning and order of those values is
    defined by the loss (presumably objectness, 4 box terms, class logits —
    confirm against YoloLoss).

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of object classes.

    ``forward`` returns a tensor of shape ``(batch, 3, H, W, num_classes + 5)``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        widened = 2 * in_channels
        per_cell = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, widened, kernel_size=3, padding=1),
            CNNBlock(widened, per_cell, kernel_size=1, bn_act=False),
        )
        self.num_classes = num_classes

    def forward(self, x):
        raw = self.pred(x)
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        split = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # Move the per-anchor prediction vector to the last axis.
        return split.permute(0, 1, 3, 4, 2)
class YOLOv3(pl.LightningModule):
    """YOLOv3 detector assembled from ``model_config`` and trained with Lightning.

    ``forward`` returns a list with one tensor per ScalePrediction head, in the
    order the heads appear in ``model_config`` (coarsest grid first); each has
    shape ``(batch, 3, H, W, num_classes + 5)``.

    Args:
        in_channels: channel count of the input image tensor.
        num_classes: number of object classes predicted per anchor.
    """

    def __init__(self, in_channels=3, num_classes=20):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()
        # config.S (one grid size per scale) is broadcast to (len(S), 3, 2) so
        # every 2-element anchor of scale i is multiplied by S[i].
        # NOTE(review): assumes config.ANCHORS has shape (len(S), 3, 2) and is
        # expressed relative to the image size — confirm against config.
        scaled_anchors = (
            torch.tensor(config.ANCHORS)
            * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(config.DEVICE)
        # A buffer follows the module when Lightning (or .to()) places it on a
        # device other than config.DEVICE, e.g. a different GPU per DDP process.
        # persistent=False keeps it out of state_dict, so checkpoints are unchanged.
        self.register_buffer("scaled_anchors", scaled_anchors, persistent=False)
        self.learning_rate = config.LEARNING_RATE
        self.weight_decay = config.WEIGHT_DECAY
        # Peak learning rate (max_lr) of the OneCycleLR schedule.
        self.best_lr = 1e-3

    def forward(self, x):
        """Run the network and return the list of per-scale predictions."""
        outputs = []  # one entry per ScalePrediction head
        route_connections = []  # feature maps saved for the upsampling path
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off the trunk; x itself is not replaced.
                outputs.append(layer(x))
                continue
            x = layer(x)
            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Concatenate with the most recently saved feature map (LIFO).
                x = torch.cat([x, route_connections.pop()], dim=1)
        return outputs

    def _create_conv_layers(self):
        """Translate ``model_config`` into an ``nn.ModuleList``.

        Tuples become CNNBlocks, ``["B", n]`` lists become ResidualBlocks,
        ``"S"`` adds a prediction branch, ``"U"`` adds a 2x upsample. The
        running ``in_channels`` tracks the channel count entering each layer.
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels
        for module in model_config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2
                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # forward() concatenates a route connection with twice these
                    # channels, so the next layer sees 3x the channels.
                    in_channels = in_channels * 3
        return layers

    def yololoss(self):
        """Return a fresh ``YoloLoss`` instance."""
        return YoloLoss()

    def _shared_step(self, batch):
        """Return the YOLO loss summed over the three output scales for one batch.

        ``batch`` is ``(x, y)`` where ``y[i]`` is the target for output scale ``i``.
        """
        x, y = batch
        out = self.forward(x)
        # Build the loss object once per step instead of once per scale.
        loss_fn = self.yololoss()
        return (
            loss_fn(out[0], y[0], self.scaled_anchors[0])
            + loss_fn(out[1], y[1], self.scaled_anchors[1])
            + loss_fn(out[2], y[2], self.scaled_anchors[2])
        )

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log(
            "test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def on_train_epoch_end(self) -> None:
        # "train_loss_epoch" is the epoch-aggregated key produced by
        # self.log("train_loss", ..., on_epoch=True).
        print(
            f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['train_loss_epoch']}"
        )

    def on_test_epoch_end(self) -> None:
        print(
            f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['test_loss_epoch']}"
        )

    def configure_optimizers(self):
        """Adam with a per-step linear OneCycleLR schedule peaking at ``best_lr``.

        The warm-up covers the first 8 epochs (``pct_start = 8 / NUM_EPOCHS``);
        OneCycleLR rejects pct_start > 1, so NUM_EPOCHS must be at least 8.
        """
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.best_lr,
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=8 / config.NUM_EPOCHS,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def on_train_end(self) -> None:
        torch.save(self.state_dict(), config.MODEL_STATE_DICT_PATH)
if __name__ == "__main__":
    # Smoke test: a 416x416 input must yield grids of stride 32, 16 and 8.
    # (The original compared against "main", so this check never ran.)
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    with torch.no_grad():  # shapes only; no autograd graph needed
        out = model(x)
    assert len(out) == 3, f"expected 3 output scales, got {len(out)}"
    for prediction, stride in zip(out, (32, 16, 8)):
        expected = (2, 3, IMAGE_SIZE // stride, IMAGE_SIZE // stride, num_classes + 5)
        assert prediction.shape == expected, (tuple(prediction.shape), expected)
    print("Image size compatibility check passed!")