from torch import nn
from torchvision import models
class CustomResNetModel(nn.Module):
    """ResNet-50 classifier with a replaced final layer.

    A torchvision ResNet-50 is constructed with ``weights`` (by default
    ``ResNet50_Weights.IMAGENET1K_V2``, as in the original implementation).
    Its parameters are optionally frozen. Its final fully connected layer
    (``fc``) is then replaced by a freshly initialised ``nn.Linear`` with
    ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the new ``fc`` layer.
        weights: Passed straight to ``torchvision.models.resnet50``. Defaults
            to ``ResNet50_Weights.IMAGENET1K_V2``. Pass ``None`` for randomly
            initialised weights, e.g. when pretrained weights are not needed
            or cannot be fetched.
        freeze_backbone: If True (default, matching the original behaviour),
            ``requires_grad`` is set to False on every parameter of the loaded
            network before the head is replaced. As a result, only the new
            ``fc`` layer receives gradients.
    """

    def __init__(
        self,
        num_classes=24,
        *,
        weights=models.ResNet50_Weights.IMAGENET1K_V2,
        freeze_backbone=True,
    ):
        super().__init__()
        self.model = models.resnet50(weights=weights)
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
        # Created after the freezing loop, so the new head's parameters keep
        # the nn.Linear default requires_grad=True and remain trainable.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-50 and return its output.

        The output has ``num_classes`` features per sample, as produced by the
        replaced ``fc`` layer. NOTE(review): ``x`` is presumably an image batch
        of shape (N, 3, H, W) normalised like the chosen weights expect;
        confirm against callers.
        """
        return self.model(x)