""" Implementation of YOLOv3 architecture """ import pytorch_lightning as pl import torch import torch.nn as nn import torch.optim as optim from torch.optim.lr_scheduler import OneCycleLR from . import config from .loss import YoloLoss model_config = [ (32, 3, 1), (64, 3, 2), ["B", 1], (128, 3, 2), ["B", 2], (256, 3, 2), ["B", 8], (512, 3, 2), ["B", 8], (1024, 3, 2), ["B", 4], # darknet 53 ends here (512, 1, 1), (1024, 3, 1), "S", (256, 1, 1), "U", (256, 1, 1), (512, 3, 1), "S", (128, 1, 1), "U", (128, 1, 1), (256, 3, 1), "S" ] class CNNBlock(pl.LightningModule): def __init__(self, in_channels, out_channels, bn_act=True, **kwargs): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs) self.bn = nn.BatchNorm2d(out_channels) self.leaky = nn.LeakyReLU(0.1) self.use_bn_act = bn_act def forward(self, x): if self.use_bn_act: return self.leaky(self.bn((self.conv(x)))) else: return self.conv(x) class ResidualBlock(pl.LightningModule): def __init__(self, channels, use_residual=True, num_repeats=1): super().__init__() self.layers = nn.ModuleList() for repeat in range(num_repeats): self.layers += [ nn.Sequential( CNNBlock(channels, channels//2, kernel_size=1), CNNBlock(channels//2, channels, kernel_size=3, padding=1) ) ] self.use_residual = use_residual self.num_repeats = num_repeats def forward(self, x): for layer in self.layers: if self.use_residual: x = x + layer(x) else: x = layer(x) return x class ScalePrediction(pl.LightningModule): def __init__(self, in_channels, num_classes): super().__init__() self.pred = nn.Sequential( CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1), CNNBlock(2 * in_channels, (num_classes + 5) * 3, kernel_size=1, bn_act=False) ) self.num_classes = num_classes def forward(self, x): return ( self.pred(x). reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3]). permute(0, 1, 3, 4, 2) ) class YOLOv3(pl.LightningModule): def __init__(self, in_channels=3, num_classes=20): super().__init__() self.num_classes = num_classes self.in_channels = in_channels self.layers = self._create_conv_layers() self.scaled_anchors = ( torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) # ? ).to(config.DEVICE) self.learning_rate = config.LEARNING_RATE self.weight_decay = config.WEIGHT_DECAY self.best_lr = 1e-3 ## ? def forward(self, x): # ? outputs = [] # for each scale route_connections = [] for layer in self.layers: if isinstance(layer, ScalePrediction): outputs.append(layer(x)) continue x = layer(x) if isinstance(layer, ResidualBlock) and layer.num_repeats == 8: route_connections.append(x) elif isinstance(layer, nn.Upsample): x = torch.cat([x, route_connections[-1]], dim=1) route_connections.pop() return outputs def _create_conv_layers(self): layers = nn.ModuleList() in_channels = self.in_channels for module in model_config: if isinstance(module, tuple): out_channels, kernel_size, stride = module layers.append( CNNBlock(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1 if kernel_size==3 else 0) ) in_channels = out_channels elif isinstance(module, list): num_repeats = module[1] layers.append( ResidualBlock(in_channels, num_repeats=num_repeats) ) elif isinstance(module, str): if module == "S": layers += [ ResidualBlock(in_channels, use_residual=False, num_repeats=1), CNNBlock(in_channels, in_channels//2, kernel_size=1), ScalePrediction(in_channels//2, num_classes=self.num_classes) ] in_channels = in_channels // 2 elif module == "U": layers.append(nn.Upsample(scale_factor=2)) in_channels = in_channels * 3 return layers def yololoss(self): return YoloLoss() def training_step(self, batch, batch_idx): x, y = batch y0, y1, y2 = y[0], y[1], y[2] out = self.forward(x) # print(out[0].shape, y0.shape) loss = ( # ? self.yololoss()(out[0], y0, self.scaled_anchors[0]) + self.yololoss()(out[1], y1, self.scaled_anchors[1]) + self.yololoss()(out[2], y2, self.scaled_anchors[2]) ) self.log( "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True ) return loss def test_step(self, batch, batch_idx): x, y = batch y0, y1, y2 = y[0], y[1], y[2] out = self.forward(x) loss = ( self.yololoss()(out[0], y0, self.scaled_anchors[0]) + self.yololoss()(out[1], y1, self.scaled_anchors[1]) + self.yololoss()(out[2], y2, self.scaled_anchors[2]) ) self.log( "test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True ) return loss def on_train_epoch_end(self) -> None: print( f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['train_loss_epoch']}" ) def on_test_epoch_end(self) -> None: print( f"Epoch: {self.current_epoch}, Loss: {self.trainer.callback_metrics['test_loss_epoch']}" ) def configure_optimizers(self): optimizer = optim.Adam( self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) scheduler = OneCycleLR( optimizer, max_lr=self.best_lr, steps_per_epoch=len(self.trainer.datamodule.train_dataloader()), epochs=config.NUM_EPOCHS, pct_start=8 / config.NUM_EPOCHS, div_factor=100, three_phase=False, final_div_factor=100, anneal_strategy="linear" ) return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] def on_train_end(self) -> None: torch.save(self.state_dict(), config.MODEL_STATE_DICT_PATH) if __name__ == "main": num_classes = 20 IMAGE_SIZE = 416 model = YOLOv3(num_classes=num_classes) x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE)) out = model(x) assert model(x)[0].shape == ( 2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5 ) assert model(x)[1].shape == ( 2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5 ) assert model(x)[2].shape == ( 2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5 ) print("Image size compatibility check passed!")