# Scraped file-view header (not code): author aagirre92, commit 8bf4df8 ("initial commit"), 1.25 kB.
import torch
import torchvision
from torch import nn
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
def create_effnetb2_model(num_classes: int = 3, seed: int = 42):
    """Create a pretrained EfficientNet-B2 with frozen features and a new head.

    Loads ``torchvision.models.EfficientNet_B2_Weights.DEFAULT``, freezes every
    parameter of ``model.features`` and replaces ``model.classifier`` with a
    freshly initialised ``Dropout(p=0.3) -> Linear`` head producing
    ``num_classes`` outputs.

    Reference:
    https://pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b2.html#torchvision.models.efficientnet_b2

    Args:
        num_classes: Number of outputs of the new classification layer.
        seed: Passed to ``torch.manual_seed`` immediately before the new head is
            built, so the head's initial weights are reproducible.

    Returns:
        Tuple ``(model, transforms)``: the modified EfficientNet-B2 model and the
        preprocessing transforms associated with the pretrained weights.
    """

    # Replacement loader that downloads the checkpoint without hash verification.
    # NOTE(review): presumably a workaround for a hash mismatch on the published
    # EfficientNet-B2 checkpoint -- confirm it is still needed with the installed
    # torchvision version.
    def get_state_dict(self, *args, **kwargs):
        # Not every torchvision version passes ``check_hash``; a default avoids a
        # KeyError when the keyword is absent.
        kwargs.pop("check_hash", None)
        return load_state_dict_from_url(self.url, *args, **kwargs)

    # Patch only for the duration of the weight loading and always restore the
    # previous implementation, so hash checking is not disabled process-wide.
    previous_get_state_dict = WeightsEnum.get_state_dict
    WeightsEnum.get_state_dict = get_state_dict
    try:
        # 1. Pretrained EffNetB2 weights (DEFAULT = best available).
        effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
        # 2. Transforms matching those weights.
        effnetb2_transforms = effnetb2_weights.transforms()
        # 3. Pretrained model instance (this is where the weights are loaded).
        effnetb2_model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)
    finally:
        WeightsEnum.get_state_dict = previous_get_state_dict

    # 4. Freeze the feature-extractor layers; only the new head stays trainable.
    for param in effnetb2_model.features.parameters():
        param.requires_grad = False

    # 5. Replace the classifier. The input width comes from the existing final
    # Linear layer instead of the hard-coded 1408.
    in_features = effnetb2_model.classifier[-1].in_features
    torch.manual_seed(seed)
    effnetb2_model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )

    return effnetb2_model, effnetb2_transforms