Đinh Ngọc Ân
first commit
f6a046f
raw
history blame
1.14 kB
import torch, torchvision
from torch import nn
def create_effnetb2_model(num_classes: int = 3,
                          seed: int = 42,
                          dropout: float = 0.3):
    """Create an EfficientNet-B2 feature-extractor model and its transforms.

    The pretrained backbone is frozen (``requires_grad = False`` on every
    existing parameter) and the classifier is replaced with a freshly
    initialised ``Dropout -> Linear`` head with ``num_classes`` outputs.

    Args:
        num_classes (int, optional): Number of output neurons in the new
            classifier layer. Defaults to 3.
        seed (int, optional): Seed passed to ``torch.manual_seed`` right
            before the new head is built, so its initial weights are
            reproducible. Note that this reseeds torch's global RNG.
            Defaults to 42.
        dropout (float, optional): Dropout probability used in the new
            classifier head. Defaults to 0.3.

    Returns:
        tuple: ``(model, transforms)`` where ``model`` is the
        ``torchvision.models.efficientnet_b2`` model with a frozen backbone
        and the new head, and ``transforms`` is the preprocessing pipeline
        returned by ``EfficientNet_B2_Weights.DEFAULT.transforms()``.
    """
    # 1. Set up pretrained EfficientNet-B2 weights and their matching transforms.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    effnetb2_transforms = effnetb2_weights.transforms()

    # 2. Set up the pretrained model (weights may be fetched if not cached).
    effnetb2 = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

    # 3. Freeze the base layers so only the new head is trainable.
    for param in effnetb2.parameters():
        param.requires_grad = False

    # 4. Replace the classifier head with one sized for `num_classes`.
    # Read the feature width from the existing final Linear layer instead of
    # hard-coding it, so the head always matches the backbone's output.
    in_features = effnetb2.classifier[-1].in_features
    torch.manual_seed(seed)  # reproducible initialisation of the new head
    effnetb2.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return effnetb2, effnetb2_transforms