Irgenija's picture
Upload 3 files
b796ff0 verified
raw
history blame contribute delete
523 Bytes
from torch import nn
from torchvision import models
class CustomResNetModel(nn.Module):
    """ResNet-50 classifier whose final layer is resized to ``num_classes``.

    A torchvision ResNet-50 is built (by default with the ``IMAGENET1K_V2``
    pretrained weights). Its backbone is optionally frozen, and its final
    fully connected layer ``fc`` is replaced with a freshly initialised
    ``nn.Linear`` that produces ``num_classes`` outputs. When the backbone is
    frozen, only that new head has trainable parameters.

    Args:
        num_classes: Number of output logits produced by the new head.
        weights: Forwarded to ``torchvision.models.resnet50``. The default
            ``"IMAGENET1K_V2"`` resolves to
            ``ResNet50_Weights.IMAGENET1K_V2``. Pass ``None`` to skip
            downloading pretrained weights, e.g. when a fine-tuned
            ``state_dict`` will be loaded right afterwards.
        freeze_backbone: If True (the default), gradients are disabled for
            every backbone parameter before the new head is attached.
    """

    def __init__(
        self,
        num_classes: int = 24,
        *,
        weights="IMAGENET1K_V2",
        freeze_backbone: bool = True,
    ):
        super().__init__()
        self.model = models.resnet50(weights=weights)

        if freeze_backbone:
            # NOTE: requires_grad=False only blocks gradient updates; BatchNorm
            # layers still refresh their running statistics whenever the
            # module is in train() mode.
            for param in self.model.parameters():
                param.requires_grad = False

        # The head is created after freezing, so its parameters keep the
        # default requires_grad=True and stay trainable.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the logits ``self.model(x)`` for the input batch ``x``.

        Presumably ``x`` is an image batch preprocessed the way the chosen
        weights expect -- TODO confirm against the data pipeline.
        """
        return self.model(x)