from torch import nn
from torchvision import models


class CustomResNetModel(nn.Module):
    """Custom classifier built on a torchvision ResNet-50 backbone.

    The backbone is created with ``models.resnet50(weights=weights)``
    (ImageNet-1k V2 weights by default). When ``freeze_backbone`` is true,
    every backbone parameter gets ``requires_grad = False``. The final fully
    connected layer is then replaced by a fresh ``nn.Linear`` with
    ``num_classes`` outputs. Because that layer is created after the freezing
    loop, it is the only trainable part when the backbone is frozen.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
        weights: Weights passed to ``models.resnet50``; ``None`` gives a
            randomly initialised backbone.
        freeze_backbone: If true, disable gradients for all backbone
            parameters so only the new ``fc`` head is trained.

    Note:
        Setting ``requires_grad = False`` does not put BatchNorm layers in
        eval mode. Their running statistics still update whenever this module
        is in training mode.
    """

    def __init__(
        self,
        num_classes: int = 24,
        *,
        weights=models.ResNet50_Weights.IMAGENET1K_V2,
        freeze_backbone: bool = True,
    ):
        super().__init__()
        self.model = models.resnet50(weights=weights)
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
        # Replaced after the freezing loop so the new head keeps the default
        # requires_grad=True and remains trainable.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement ``fc`` head.

        NOTE(review): assumes ``x`` is an image batch that ResNet-50 accepts
        (typically N x 3 x H x W) — confirm against callers.
        """
        return self.model(x)